This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[SCRIPT] Add ESIM for text matching #689

Merged 5 commits on Jul 4, 2019
scripts/natural_language_inference/decomposable_attention.py
@@ -27,15 +27,15 @@
from mxnet.gluon import nn


-class NLIModel(gluon.HybridBlock):
+class DecomposableAttentionModel(gluon.HybridBlock):
    """
    A Decomposable Attention Model for Natural Language Inference
    using intra-sentence attention.
    Arxiv paper: https://arxiv.org/pdf/1606.01933.pdf
    """
    def __init__(self, vocab_size, word_embed_size, hidden_size,
                 dropout=0., intra_attention=False, **kwargs):
-        super(NLIModel, self).__init__(**kwargs)
+        super(DecomposableAttentionModel, self).__init__(**kwargs)
        self.word_embed_size = word_embed_size
        self.hidden_size = hidden_size
        self.use_intra_attention = intra_attention
scripts/natural_language_inference/esim.py (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Build an Enhancing LSTM model for Natural Language Inference
"""

__all__ = ['ESIMModel']

from mxnet.gluon import nn, rnn

EPS = 1e-12


class ESIMModel(nn.HybridBlock):
""""Enhanced LSTM for Natural Language Inference" Qian Chen,
Xiaodan Zhu, Zhenhua Ling, Si Wei, Hui Jiang, Diana Inkpen. ACL (2017)

Parameters
----------
vocab_size: int
Number of words in vocab
word_embed_size : int
Dimension of word vector
hidden_size : int
Number of hidden units in lstm cell
dense_size : int
Number of hidden units in dense layer
num_classes : int
Number of categories
dropout : int
Dropout prob
"""

    def __init__(self, vocab_size, num_classes, word_embed_size, hidden_size, dense_size,
                 dropout=0., **kwargs):
        super(ESIMModel, self).__init__(**kwargs)
        with self.name_scope():
            self.word_emb = nn.Embedding(vocab_size, word_embed_size)
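            # axes=1 broadcasts the dropout mask along the time axis, so the
            # same embedding channels are dropped at every position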
            self.embedding_dropout = nn.Dropout(dropout, axes=1)
            self.lstm_encoder1 = rnn.LSTM(hidden_size, input_size=word_embed_size, bidirectional=True, layout='NTC')
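            # the composition input is [encode; align; mul; sub] -- four tensors
            # of width 2 * hidden_size each, hence in_units = hidden_size * 2 * 4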
            self.ff_proj = nn.Dense(hidden_size, in_units=hidden_size * 2 * 4, flatten=False, activation='relu')
            self.lstm_encoder2 = rnn.LSTM(hidden_size, input_size=hidden_size, bidirectional=True, layout='NTC')

            self.classifier = nn.HybridSequential()
            if dropout:
                self.classifier.add(nn.Dropout(rate=dropout))
            # size the hidden layer with the documented dense_size parameter
            self.classifier.add(nn.Dense(units=dense_size, activation='relu'))
            if dropout:
                self.classifier.add(nn.Dropout(rate=dropout))
            self.classifier.add(nn.Dense(units=num_classes))

    def _soft_attention_align(self, F, x1, x2):
        # attention shape: (batch, x1_seq_len, x2_seq_len)
        attention = F.batch_dot(x1, x2, transpose_b=True)

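        # softmax over the last axis turns each score row into weights, so every
        # x1 token gets a weighted summary of x2 (and vice versa after transposing)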
        x1_align = F.batch_dot(attention.softmax(), x2)
        x2_align = F.batch_dot(attention.transpose([0, 2, 1]).softmax(), x1)

        return x1_align, x2_align

    def _submul(self, F, x1, x2):
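        # "enhancement" features from the ESIM paper: element-wise product and
        # difference between a sequence and its aligned counterpart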
        mul = x1 * x2
        sub = x1 - x2

        return F.concat(mul, sub, dim=-1)

    def _pool(self, F, x):
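        # aggregate over the time axis with both average and max pooling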
        p1 = x.mean(axis=1)
        p2 = x.max(axis=1)

        return F.concat(p1, p2, dim=-1)

    def hybrid_forward(self, F, x1, x2):
        # x1_embed, x2_embed shape: (batch, seq_len, word_embed_size)
        x1_embed = self.embedding_dropout(self.word_emb(x1))
        x2_embed = self.embedding_dropout(self.word_emb(x2))

        x1_lstm_encode = self.lstm_encoder1(x1_embed)
        x2_lstm_encode = self.lstm_encoder1(x2_embed)
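        # the bidirectional LSTM doubles the channel size:
        # (batch, seq_len, 2 * hidden_size)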

        # attention
        x1_align, x2_align = self._soft_attention_align(F, x1_lstm_encode, x2_lstm_encode)

        # compose
        x1_combined = F.concat(x1_lstm_encode, x1_align,
                               self._submul(F, x1_lstm_encode, x1_align), dim=-1)
        x2_combined = F.concat(x2_lstm_encode, x2_align,
                               self._submul(F, x2_lstm_encode, x2_align), dim=-1)

        x1_compose = self.lstm_encoder2(self.ff_proj(x1_combined))
        x2_compose = self.lstm_encoder2(self.ff_proj(x2_combined))

        # aggregate
        x1_agg = self._pool(F, x1_compose)
        x2_agg = self._pool(F, x2_compose)

        # final fully-connected classifier
        output = self.classifier(F.concat(x1_agg, x2_agg, dim=-1))

        return output
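
A minimal smoke test of the new block (illustrative only, not part of this diff; the toy sizes and inputs are assumptions):

    import mxnet as mx
    from esim import ESIMModel

    model = ESIMModel(vocab_size=100, num_classes=3, word_embed_size=8,
                      hidden_size=16, dense_size=16, dropout=0.2)
    model.initialize()
    # two padded batches of token ids, shape (batch=2, seq_len=5)
    premise = mx.nd.array([[1, 2, 3, 4, 0], [5, 6, 7, 0, 0]])
    hypothesis = mx.nd.array([[2, 3, 4, 0, 0], [6, 7, 8, 9, 0]])
    logits = model(premise, hypothesis)
    print(logits.shape)  # (2, 3)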
scripts/natural_language_inference/index.rst (3 additions, 3 deletions)
@@ -23,13 +23,13 @@ Train the model without intra-sentence attention:

.. code-block:: console

-   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-basic --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5
+   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-basic --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --fix-embedding

Test:

.. code-block:: console

-   $ python3 main.py --test-file data/snli_1.0/test.txt --model-dir output/snli-basic --gpu-id 0 --mode test --output-dir output/snli-basic/test
+   $ python3 main.py --test-file data/snli_1.0/test.txt --model-dir output/snli-basic --gpu-id 0 --mode test --output-dir output/snli-basic/test

We achieve 85.0% accuracy on the SNLI test set, comparable to 86.3% reported in the
original paper. `[Training log] <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/natural_language_inference/decomposable_attention_snli.log>`__
@@ -38,7 +38,7 @@ Train the model with intra-sentence attention:

.. code-block:: console

-   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-intra --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --intra-attention
+   $ python3 main.py --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-intra --batch-size 32 --print-interval 5000 --lr 0.025 --epochs 300 --gpu-id 0 --dropout 0.2 --weight-decay 1e-5 --intra-attention --fix-embedding

Test:

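With the --model flag that this PR makes required in main.py, an ESIM training run might look like the following (the output path and hyper-parameter values are illustrative assumptions, not part of this diff; every flag shown does appear in the argparse changes below):

   $ python3 main.py --model esim --train-file data/snli_1.0/train.txt --test-file data/snli_1.0/dev.txt --output-dir output/snli-esim --batch-size 32 --print-interval 5000 --lr 0.001 --optimizer adam --epochs 40 --gpu-id 0 --dropout 0.3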
scripts/natural_language_inference/main.py (22 additions, 7 deletions)
@@ -43,7 +43,8 @@
from mxnet import gluon, autograd
import gluonnlp as nlp

-from decomposable_attention import NLIModel
+from decomposable_attention import DecomposableAttentionModel
+from esim import ESIMModel
from dataset import read_dataset, prepare_data_loader, build_vocab
from utils import logging_config

Expand All @@ -67,6 +68,8 @@ def parse_args():
help='batch size')
parser.add_argument('--print-interval', type=int, default=20,
help='the interval of two print')
parser.add_argument('--model', choices=['da', 'esim'], default=None, required=True,
help='which model to use')
parser.add_argument('--mode', choices=['train', 'test'], default='train',
help='train or test')
    parser.add_argument('--lr', type=float, default=0.025,
@@ -75,6 +78,8 @@
                        help='maximum number of epochs to train')
    parser.add_argument('--embedding', default='glove',
                        help='word embedding type')
+   parser.add_argument('--fix-embedding', action='store_true',
+                       help='whether to fix pretrained word embedding')
    parser.add_argument('--embedding-source', default='glove.840B.300d',
                        help='embedding file source')
    parser.add_argument('--embedding-size', type=int, default=300,
@@ -89,6 +94,8 @@
                        help='random seed')
    parser.add_argument('--dropout', type=float, default=0.,
                        help='dropout rate')
+   parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adagrad',
+                       help='optimization method')
    parser.add_argument('--weight-decay', type=float, default=0.,
                        help='l2 regularization weight')
    parser.add_argument('--intra-attention', action='store_true',
@@ -107,10 +114,11 @@ def train_model(model, train_data_loader, val_data_loader, embedding, ctx, args):
    model.collect_params().initialize(mx.init.Normal(0.01), ctx=ctx)
    model.word_emb.weight.set_data(embedding.idx_to_vec)
    # Fix word embedding
-   model.word_emb.weight.grad_req = 'null'
+   if args.fix_embedding:
+       model.word_emb.weight.grad_req = 'null'

    loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
-   trainer = gluon.Trainer(model.collect_params(), 'adagrad',
+   trainer = gluon.Trainer(model.collect_params(), args.optimizer,
                            {'learning_rate': args.lr,
                             'wd': args.weight_decay,
                             'clip_gradient': 5})
@@ -181,6 +189,15 @@ def test_model(model, data_loader, loss_func, ctx):
    loss /= len(data_loader)
    return loss, acc

+def build_model(args, vocab):
+    if args.model == 'da':
+        model = DecomposableAttentionModel(len(vocab), args.embedding_size, args.hidden_size,
+                                           args.dropout, args.intra_attention)
+    elif args.model == 'esim':
+        # 3 = SNLI's entailment/neutral/contradiction classes; no --dense-size
+        # flag exists, so hidden_size doubles as the classifier width (an
+        # assumption), and dropout is passed by keyword so it does not land in
+        # the positional dense_size slot
+        model = ESIMModel(len(vocab), 3, args.embedding_size, args.hidden_size,
+                          dense_size=args.hidden_size, dropout=args.dropout)
+    return model

def main(args):
"""
Entry point: train or test.
Expand Down Expand Up @@ -211,8 +228,7 @@ def main(args):
        train_data_loader = prepare_data_loader(args, train_dataset, vocab)
        val_data_loader = prepare_data_loader(args, val_dataset, vocab, test=True)

-       model = NLIModel(len(vocab), args.embedding_size, args.hidden_size,
-                        args.dropout, args.intra_attention)
+       model = build_model(args, vocab)
        train_model(model, train_data_loader, val_data_loader, vocab.embedding, ctx, args)
    elif args.mode == 'test':
        model_args = argparse.Namespace(**json.load(
@@ -221,8 +237,7 @@ def main(args):
            open(os.path.join(args.model_dir, 'vocab.jsons')).read())
        val_dataset = read_dataset(args, 'test_file')
        val_data_loader = prepare_data_loader(args, val_dataset, vocab, test=True)
-       model = NLIModel(len(vocab), model_args.embedding_size,
-                        model_args.hidden_size, 0., model_args.intra_attention)
+       model = build_model(model_args, vocab)
        model.load_parameters(os.path.join(
            args.model_dir, 'checkpoints', 'valid_best.params'), ctx=ctx)
        loss_func = gluon.loss.SoftmaxCrossEntropyLoss()