diff --git a/scripts/esim/__init__.py b/scripts/esim/__init__.py
deleted file mode 100644
index a17da4f6d7..0000000000
--- a/scripts/esim/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# coding: utf-8
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=wildcard-import
-"""esim example
-"""
-
-from . import esim
diff --git a/scripts/esim/esim.py b/scripts/esim/esim.py
deleted file mode 100644
index 0ce55f11df..0000000000
--- a/scripts/esim/esim.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# coding: utf-8
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Build an Enhancing LSTM model for Natural Language Inference
-"""
-
-__all__ = ['ESIM']
-
-from mxnet.gluon import nn, rnn
-
-
-class ESIM(nn.HybridBlock):
-    """"Enhanced LSTM for Natural Language Inference" Qian Chen,
-    Xiaodan Zhu, Zhenhua Ling, Si Wei, Hui Jiang, Diana Inkpen. ACL (2017)
-    https://arxiv.org/pdf/1609.06038.pdf
-
-    Parameters
-    ----------
-    nwords: int
-        Number of words in vocab
-    nword_dims : int
-        Dimension of word vector
-    nhiddens : int
-        Number of hidden units in lstm cell
-    ndense_units : int
-        Number of hidden units in dense layer
-    nclasses : int
-        Number of categories
-    drop_out : int
-        Dropout prob
-    """
-
-    def __init__(self, vocab_size, nword_dims, nhidden_units, ndense_units,
-                 nclasses, dropout=0.0, **kwargs):
-        super(ESIM, self).__init__(**kwargs)
-        with self.name_scope():
-            self.embedding_encoder = nn.Embedding(vocab_size, nword_dims)
-            self.batch_norm = nn.BatchNorm(axis=-1)
-            self.lstm_encoder1 = rnn.LSTM(nhidden_units, bidirectional=True)
-
-            self.projection = nn.HybridSequential()
-            self.projection.add(nn.BatchNorm(axis=-1),
-                                nn.Dropout(dropout),
-                                nn.Dense(nhidden_units, activation='relu', flatten=False))
-
-            self.lstm_encoder2 = rnn.LSTM(nhidden_units, bidirectional=True)
-
-            self.fc_encoder = nn.HybridSequential()
-            self.fc_encoder.add(nn.BatchNorm(axis=-1),
-                                nn.Dropout(dropout),
-                                nn.Dense(ndense_units),
-                                nn.ELU(),
-                                nn.BatchNorm(axis=-1),
-                                nn.Dropout(dropout),
-                                nn.Dense(nclasses))
-
-            self.avg_pool = nn.GlobalAvgPool1D()
-            self.max_pool = nn.GlobalMaxPool1D()
-
-    def _soft_attention_align(self, F, x1, x2, mask1, mask2):
-        # x1 shape: (batch, x1_seq_len, nhidden_units*2)
-        # x2 shape: (batch, x2_seq_len, nhidden_units*2)
-        # mask1 shape: (batch, x1_seq_len)
-        # mask2 shape: (batch, x2_seq_len)
-        # attention shape: (batch, x1_seq_len, x2_seq_len)
-        attention = F.batch_dot(x1, x2, transpose_b=True)
-
-        # weight1 shape: (batch, x1_seq_len, x2_seq_len)
-        weight1 = F.softmax(attention + F.expand_dims(mask2, axis=1), axis=-1)
-        # x1_align shape: (batch, x1_seq_len, nhidden_units*2)
-        x1_align = F.batch_dot(weight1, x2)
-
-        # weight2 shape: (batch, x1_seq_len, x2_seq_len)
-        weight2 = F.softmax(attention + F.expand_dims(mask1, axis=2), axis=1)
-        # x2_align shape: (batch, x2_seq_len, nhidden_units*2)
-        x2_align = F.batch_dot(weight2, x1, transpose_a=True)
-
-        return x1_align, x2_align
-
-    def _submul(self, F, x1, x2):
-        mul = F.multiply(x1, x2)
-        sub = F.subtract(x1, x2)
-
-        return F.concat(mul, sub, dim=-1)
-
-    def _pooling(self, F, x):
-        # x : NCW  C <----> input channels  W <----> seq_len
-        # p1, p2 shape: (batch, input channels)
-        p1 = F.squeeze(self.avg_pool(x), axis=-1)
-        p2 = F.squeeze(self.max_pool(x), axis=-1)
-
-        return F.concat(p1, p2, dim=-1)
-
-    def hybrid_forward(self, F, x1, x2, mask1, mask2):  # pylint: disable=arguments-differ
-        # x1, x2 shape: (batch, x1_seq_len), (batch, x2_seq_len)
-        # mask1, mask2 shape: (batch, x1_seq_len), (batch, x2_seq_len)
-        # x1_embed shape: (batch, x1_seq_len, nword_dims)
-        # x2_embed shape: (batch, x2_seq_len, nword_dims)
-        x1_embed = self.batch_norm(self.embedding_encoder(x1))
-        x2_embed = self.batch_norm(self.embedding_encoder(x2))
-
-        x1_lstm_encode = self.lstm_encoder1(x1_embed)
-        x2_lstm_encode = self.lstm_encoder1(x2_embed)
-
-        # attention
-        x1_algin, x2_algin = self._soft_attention_align(F, x1_lstm_encode, x2_lstm_encode,
-                                                        mask1, mask2)
-
-        # compose
-        x1_combined = F.concat(x1_lstm_encode, x1_algin,
-                               self._submul(F, x1_lstm_encode, x1_algin), dim=-1)
-        x2_combined = F.concat(x2_lstm_encode, x2_algin,
-                               self._submul(F, x2_lstm_encode, x2_algin), dim=-1)
-
-        # x1_compose shape: (batch, x1_seq_len, nhidden_units*2)
-        # x2_compose shape: (batch, x2_seq_len, nhidden_units*2)
-        x1_compose = self.lstm_encoder2(self.projection(x1_combined))
-        x2_compose = self.lstm_encoder2(self.projection(x2_combined))
-
-        # aggregate
-        # NWC ------> NCW
-        x1_compose = F.transpose(x1_compose, axes=(0, 2, 1))
-        x2_compose = F.transpose(x2_compose, axes=(0, 2, 1))
-        x1_agg = self._pooling(F, x1_compose)
-        x2_agg = self._pooling(F, x2_compose)
-
-        # fully connection
-        output = self.fc_encoder(F.concat(x1_agg, x2_agg, dim=-1))
-
-        return output
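For reference, the soft attention alignment at the heart of both the removed block above and the re-added one below can be run in isolation. The following standalone sketch is illustrative only: it is not part of the patch, and the toy shapes and random inputs are made up.

    # Illustrative sketch (not part of the patch): ESIM's soft alignment on plain NDArrays.
    import mxnet as mx

    batch, len1, len2, dim = 2, 3, 4, 6
    x1 = mx.nd.random.normal(shape=(batch, len1, dim))   # encoded premise
    x2 = mx.nd.random.normal(shape=(batch, len2, dim))   # encoded hypothesis

    # attention[b, i, j] is the dot product of premise token i and hypothesis token j
    attention = mx.nd.batch_dot(x1, x2, transpose_b=True)               # (batch, len1, len2)

    # each premise position is summarized by a softmax-weighted mix of hypothesis
    # positions, and vice versa
    x1_align = mx.nd.batch_dot(mx.nd.softmax(attention, axis=-1), x2)   # (batch, len1, dim)
    attention_t = mx.nd.transpose(attention, axes=(0, 2, 1))            # (batch, len2, len1)
    x2_align = mx.nd.batch_dot(mx.nd.softmax(attention_t, axis=-1), x1) # (batch, len2, dim)

    print(x1_align.shape, x2_align.shape)  # (2, 3, 6) (2, 4, 6)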
diff --git a/scripts/natural_language_inference/decomposable_attention.py b/scripts/natural_language_inference/decomposable_attention.py
index f20f0f4bfc..f0afbe7441 100644
--- a/scripts/natural_language_inference/decomposable_attention.py
+++ b/scripts/natural_language_inference/decomposable_attention.py
@@ -27,7 +27,7 @@
 from mxnet.gluon import nn
 
 
-class NLIModel(gluon.HybridBlock):
+class DecomposableAttentionModel(gluon.HybridBlock):
     """
     A Decomposable Attention Model for Natural Language Inference
     using intra-sentence attention.
diff --git a/scripts/natural_language_inference/esim.py b/scripts/natural_language_inference/esim.py
new file mode 100644
index 0000000000..8f7c99646a
--- /dev/null
+++ b/scripts/natural_language_inference/esim.py
@@ -0,0 +1,117 @@
+# coding: utf-8
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Build an Enhanced LSTM (ESIM) model for Natural Language Inference
+"""
+
+__all__ = ['ESIMModel']
+
+from mxnet.gluon import nn, rnn
+
+
+class ESIMModel(nn.HybridBlock):
+    """"Enhanced LSTM for Natural Language Inference" Qian Chen,
+    Xiaodan Zhu, Zhenhua Ling, Si Wei, Hui Jiang, Diana Inkpen. ACL (2017)
+    https://arxiv.org/pdf/1609.06038.pdf
+
+    Parameters
+    ----------
+    vocab_size : int
+        Number of words in the vocabulary
+    num_classes : int
+        Number of categories
+    word_embed_size : int
+        Dimension of the word vectors
+    hidden_size : int
+        Number of hidden units in the LSTM cells
+    dropout : float
+        Dropout probability
+    """
+
+    def __init__(self, vocab_size, num_classes, word_embed_size, hidden_size,
+                 dropout=0., **kwargs):
+        super().__init__(**kwargs)
+        with self.name_scope():
+            self.word_emb = nn.Embedding(vocab_size, word_embed_size)
+            self.embedding_dropout = nn.Dropout(dropout, axes=1)
+            self.lstm_encoder1 = rnn.LSTM(hidden_size, input_size=word_embed_size,
+                                          bidirectional=True, layout='NTC')
+            self.ff_proj = nn.Dense(hidden_size, in_units=hidden_size * 2 * 4,
+                                    flatten=False, activation='relu')
+            self.lstm_encoder2 = rnn.LSTM(hidden_size, input_size=hidden_size,
+                                          bidirectional=True, layout='NTC')
+
+            self.classifier = nn.HybridSequential()
+            if dropout:
+                self.classifier.add(nn.Dropout(rate=dropout))
+            self.classifier.add(nn.Dense(units=hidden_size, activation='relu'))
+            if dropout:
+                self.classifier.add(nn.Dropout(rate=dropout))
+            self.classifier.add(nn.Dense(units=num_classes))
+
+    def _soft_attention_align(self, F, x1, x2):
+        # attention shape: (batch, x1_seq_len, x2_seq_len)
+        attention = F.batch_dot(x1, x2, transpose_b=True)
+
+        x1_align = F.batch_dot(attention.softmax(), x2)
+        x2_align = F.batch_dot(attention.transpose([0, 2, 1]).softmax(), x1)
+
+        return x1_align, x2_align
+
+    def _submul(self, F, x1, x2):
+        mul = x1 * x2
+        sub = x1 - x2
+
+        return F.concat(mul, sub, dim=-1)
+
+    def _pool(self, F, x):
+        p1 = x.mean(axis=1)
+        p2 = x.max(axis=1)
+
+        return F.concat(p1, p2, dim=-1)
+
+    def hybrid_forward(self, F, x1, x2):  # pylint: disable=arguments-differ
+        # x1_embed, x2_embed shape: (batch, seq_len, word_embed_size)
+        x1_embed = self.embedding_dropout(self.word_emb(x1))
+        x2_embed = self.embedding_dropout(self.word_emb(x2))
+
+        x1_lstm_encode = self.lstm_encoder1(x1_embed)
+        x2_lstm_encode = self.lstm_encoder1(x2_embed)
+
+        # attention
+        x1_align, x2_align = self._soft_attention_align(F, x1_lstm_encode, x2_lstm_encode)
+
+        # compose
+        x1_combined = F.concat(x1_lstm_encode, x1_align,
+                               self._submul(F, x1_lstm_encode, x1_align), dim=-1)
+        x2_combined = F.concat(x2_lstm_encode, x2_align,
+                               self._submul(F, x2_lstm_encode, x2_align), dim=-1)
+
+        x1_compose = self.lstm_encoder2(self.ff_proj(x1_combined))
+        x2_compose = self.lstm_encoder2(self.ff_proj(x2_combined))
+
+        # aggregate
+        x1_agg = self._pool(F, x1_compose)
+        x2_agg = self._pool(F, x2_compose)
+
+        # fully connected classifier
+        output = self.classifier(F.concat(x1_agg, x2_agg, dim=-1))
+
+        return output
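A quick smoke test of the block added above. This is illustrative only: the vocabulary size, hyper-parameters and token ids are made-up values, and the snippet assumes it is run from scripts/natural_language_inference/ so that esim.py is importable.

    import mxnet as mx
    from esim import ESIMModel

    model = ESIMModel(vocab_size=100, num_classes=3, word_embed_size=16,
                      hidden_size=8, dropout=0.2)
    model.initialize(mx.init.Normal(0.01))

    # token ids for a toy premise/hypothesis batch; shape (batch, seq_len)
    premise = mx.nd.array([[1, 5, 3, 0], [2, 2, 7, 9]])
    hypothesis = mx.nd.array([[4, 1, 0], [6, 8, 2]])

    scores = model(premise, hypothesis)   # unnormalized class scores
    print(scores.shape)                   # (2, 3)

Note that the premise and hypothesis may have different lengths; the alignment step handles the rectangular attention matrix, and mean/max pooling over the time axis removes the length dependence before classification.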
diff --git a/scripts/natural_language_inference/index.rst b/scripts/natural_language_inference/index.rst
index 2ae34d95f4..90a15980fd 100644
--- a/scripts/natural_language_inference/index.rst
+++ b/scripts/natural_language_inference/index.rst
@@ -23,13 +23,13 @@ Train the model without intra-sentence attention:
 
 .. code-block:: console
 
-   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-basic --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5
+   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-basic --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --model da --fix-embedding
 
 Test:
 
 .. code-block:: console
 
-   $ python3 main.py --test-file data/snli_1.0/test.txt --model-dir output/snli-basic --gpu-id 0 --mode test --output-dir output/snli-basic/test
+   $ python3 main.py --test-file data/snli_1.0/test.txt --model-dir output/snli-basic --gpu-id 0 --mode test --output-dir output/snli-basic/test --model da
 
 We achieve 85.0% accuracy on the SNLI test set, comparable to 86.3% reported in the original paper. `[Training log] `__
 
@@ -38,7 +38,7 @@ Train the model with intra-sentence attention:
 
 .. code-block:: console
 
-   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-intra --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --intra-attention
+   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-intra --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --intra-attention --model da --fix-embedding
 
 Test:
 
diff --git a/scripts/natural_language_inference/main.py b/scripts/natural_language_inference/main.py
index 442de57d1b..8c2d2205eb 100644
--- a/scripts/natural_language_inference/main.py
+++ b/scripts/natural_language_inference/main.py
@@ -43,7 +43,8 @@
 from mxnet import gluon, autograd
 import gluonnlp as nlp
 
-from decomposable_attention import NLIModel
+from decomposable_attention import DecomposableAttentionModel
+from esim import ESIMModel
 from dataset import read_dataset, prepare_data_loader, build_vocab
 from utils import logging_config
 
@@ -67,6 +68,8 @@ def parse_args():
                         help='batch size')
     parser.add_argument('--print-interval', type=int, default=20,
                         help='the interval of two print')
+    parser.add_argument('--model', choices=['da', 'esim'], default=None, required=True,
+                        help='which model to use')
     parser.add_argument('--mode', choices=['train', 'test'], default='train',
                         help='train or test')
     parser.add_argument('--lr', type=float, default=0.025,
@@ -75,6 +78,8 @@ def parse_args():
                         help='maximum number of epochs to train')
     parser.add_argument('--embedding', default='glove',
                         help='word embedding type')
+    parser.add_argument('--fix-embedding', action='store_true',
+                        help='whether to fix the pretrained word embedding')
     parser.add_argument('--embedding-source', default='glove.840B.300d',
                         help='embedding file source')
     parser.add_argument('--embedding-size', type=int, default=300,
@@ -89,6 +94,8 @@ def parse_args():
                         help='random seed')
     parser.add_argument('--dropout', type=float, default=0.,
                         help='dropout rate')
+    parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adagrad',
+                        help='optimization method')
     parser.add_argument('--weight-decay', type=float, default=0.,
                         help='l2 regularization weight')
     parser.add_argument('--intra-attention', action='store_true',
@@ -107,10 +114,11 @@ def train_model(model, train_data_loader, val_data_loader, embedding, ctx, args)
     model.collect_params().initialize(mx.init.Normal(0.01), ctx=ctx)
     model.word_emb.weight.set_data(embedding.idx_to_vec)
     # Fix word embedding
-    model.word_emb.weight.grad_req = 'null'
+    if args.fix_embedding:
+        model.word_emb.weight.grad_req = 'null'
     loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
-    trainer = gluon.Trainer(model.collect_params(), 'adagrad',
+    trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                             {'learning_rate': args.lr,
                              'wd': args.weight_decay,
                              'clip_gradient': 5})
 
@@ -181,6 +189,15 @@ def test_model(model, data_loader, loss_func, ctx):
     loss /= len(data_loader)
     return loss, acc
 
+def build_model(args, vocab):
+    if args.model == 'da':
+        model = DecomposableAttentionModel(len(vocab), args.embedding_size, args.hidden_size,
+                                           args.dropout, args.intra_attention)
+    elif args.model == 'esim':
+        model = ESIMModel(len(vocab), 3, args.embedding_size, args.hidden_size,
+                          args.dropout)
+    return model
+
 def main(args):
     """
     Entry point: train or test.
@@ -211,8 +228,7 @@ def main(args):
         train_data_loader = prepare_data_loader(args, train_dataset, vocab)
         val_data_loader = prepare_data_loader(args, val_dataset, vocab, test=True)
 
-        model = NLIModel(len(vocab), args.embedding_size, args.hidden_size,
-                         args.dropout, args.intra_attention)
+        model = build_model(args, vocab)
        train_model(model, train_data_loader, val_data_loader, vocab.embedding, ctx, args)
     elif args.mode == 'test':
         model_args = argparse.Namespace(**json.load(
@@ -221,8 +237,7 @@ def main(args):
             open(os.path.join(args.model_dir, 'vocab.jsons')).read())
         val_dataset = read_dataset(args, 'test_file')
         val_data_loader = prepare_data_loader(args, val_dataset, vocab, test=True)
-        model = NLIModel(len(vocab), model_args.embedding_size,
-                         model_args.hidden_size, 0., model_args.intra_attention)
+        model = build_model(model_args, vocab)
         model.load_parameters(os.path.join(
             args.model_dir, 'checkpoints', 'valid_best.params'), ctx=ctx)
         loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
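With the patch applied, the new model is selected through the --model flag. The invocations below are illustrative only: they mirror the commands documented in index.rst plus the new flags, and the output directory, optimizer choice and learning rate are placeholders rather than settings taken from this patch.

    $ python3 main.py --model esim --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-esim --optimizer adam --lr 0.0004 --batch-size 32 --epochs 300 --gpu-id 0 --dropout 0.2 --fix-embedding
    $ python3 main.py --model esim --mode test --test-file data/snli_1.0/test.txt --model-dir output/snli-esim --gpu-id 0 --output-dir output/snli-esim/test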