This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Fix nmt #74

Closed
wants to merge 172 commits
Changes from all commits
172 commits
8478f49
Migrate gluon-nlp-toolkit
Apr 4, 2018
3b7fc0c
add translation dataset
sxjscience Apr 12, 2018
001acc6
fix
sxjscience Apr 12, 2018
01fc2b0
fix
sxjscience Apr 12, 2018
7b967e7
fix
sxjscience Apr 12, 2018
648097e
fix
sxjscience Apr 12, 2018
0405794
fix
sxjscience Apr 12, 2018
612e338
change default behavior of skip_empty to False
sxjscience Apr 12, 2018
4dbb930
try to accelerate
sxjscience Apr 12, 2018
c7bb469
try to accelerate
sxjscience Apr 12, 2018
0b1371b
fix
sxjscience Apr 12, 2018
1464e51
try to accelerate
sxjscience Apr 12, 2018
6042052
fix
sxjscience Apr 12, 2018
9c0e6dd
fix
sxjscience Apr 12, 2018
c50e16b
fix
sxjscience Apr 12, 2018
cc1769a
add TextLineDataset
sxjscience Apr 12, 2018
f97595e
fix docstring
sxjscience Apr 12, 2018
0cae0e8
try to update FixedBucketSampler
sxjscience Apr 13, 2018
a1e803f
fix
sxjscience Apr 13, 2018
99110c3
fix
sxjscience Apr 13, 2018
0bbf0c9
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 18, 2018
cf8519f
add beam search
sxjscience Apr 18, 2018
3c03dfa
hybridize
sxjscience Apr 18, 2018
8038b21
fix testing
sxjscience Apr 18, 2018
4f1d48a
return int32
sxjscience Apr 18, 2018
e93236a
fix
sxjscience Apr 18, 2018
7b200ba
fix lint
sxjscience Apr 18, 2018
f6720fd
add to __init__
sxjscience Apr 18, 2018
3c21d65
fix
sxjscience Apr 18, 2018
f726dfe
fix
sxjscience Apr 18, 2018
b374167
fix
sxjscience Apr 18, 2018
032a7e9
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 19, 2018
637912d
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 19, 2018
8d3fb57
try to simplify
sxjscience Apr 19, 2018
22e7ab7
fix
sxjscience Apr 19, 2018
93d7038
add translation
sxjscience Apr 19, 2018
798588d
update
sxjscience Apr 19, 2018
2e55fd7
fix
sxjscience Apr 19, 2018
243a5b5
update init
sxjscience Apr 19, 2018
41ba3de
fix
sxjscience Apr 19, 2018
ad20914
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 19, 2018
ba5f7c0
Add attention
sxjscience Apr 19, 2018
f4d842e
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 20, 2018
d4685c5
add nmt models
sxjscience Apr 20, 2018
d17b4ef
fix bleu
sxjscience Apr 20, 2018
1365c43
try to add tokenizeddataset
sxjscience Apr 20, 2018
1d4b7cc
fix
sxjscience Apr 20, 2018
5e78daa
fix
sxjscience Apr 20, 2018
4503847
update
sxjscience Apr 20, 2018
e21970a
update
sxjscience Apr 20, 2018
6289c20
fix
sxjscience Apr 20, 2018
705e0a6
fix
sxjscience Apr 20, 2018
58be334
fix
sxjscience Apr 20, 2018
bcd70b9
fix
sxjscience Apr 20, 2018
a554377
fix
sxjscience Apr 20, 2018
1998439
fix
sxjscience Apr 20, 2018
4aeed20
fix
sxjscience Apr 20, 2018
95ae1e9
fix
sxjscience Apr 20, 2018
8ee3eee
update
sxjscience Apr 20, 2018
fac5bea
fix
sxjscience Apr 20, 2018
4acfdd2
fix
sxjscience Apr 20, 2018
f279193
fix
sxjscience Apr 20, 2018
689371d
fix
sxjscience Apr 20, 2018
3d57383
update
sxjscience Apr 20, 2018
a81aa42
fix
sxjscience Apr 20, 2018
c091a4d
try to fix
sxjscience Apr 20, 2018
c81c0eb
update
sxjscience Apr 21, 2018
df5b436
fix
sxjscience Apr 21, 2018
c0f1a9e
fix
sxjscience Apr 21, 2018
b211218
fix
sxjscience Apr 21, 2018
d319bb7
update
sxjscience Apr 21, 2018
1c53139
fix
sxjscience Apr 21, 2018
0426de5
update
sxjscience Apr 21, 2018
7411f14
fix
sxjscience Apr 21, 2018
ca67ca0
fix
sxjscience Apr 21, 2018
0e15b75
fix
sxjscience Apr 21, 2018
aef9c9d
fix
sxjscience Apr 21, 2018
7967cd9
fix
sxjscience Apr 21, 2018
75c3c11
fix
sxjscience Apr 21, 2018
e7064c8
fix
sxjscience Apr 21, 2018
59090b7
fix
sxjscience Apr 21, 2018
87c4e49
update
sxjscience Apr 21, 2018
507ea61
fix
sxjscience Apr 21, 2018
4ccc835
fix
sxjscience Apr 21, 2018
e4b8a15
update
sxjscience Apr 21, 2018
77ab79e
fix
sxjscience Apr 21, 2018
8e37b63
fix
sxjscience Apr 21, 2018
710c8e8
fix
sxjscience Apr 21, 2018
3345024
fix
sxjscience Apr 21, 2018
7c09b2b
update
sxjscience Apr 21, 2018
3402806
update
sxjscience Apr 21, 2018
289bfe5
fix
sxjscience Apr 21, 2018
b4359c3
fix log
sxjscience Apr 21, 2018
a6f712f
fix
sxjscience Apr 21, 2018
39898bc
update
sxjscience Apr 21, 2018
e306ce7
update
sxjscience Apr 21, 2018
d9d93d3
fix
sxjscience Apr 21, 2018
adb0903
update
sxjscience Apr 21, 2018
b6ba942
fix
sxjscience Apr 21, 2018
e4c3664
update
sxjscience Apr 21, 2018
1c35f5c
fix
sxjscience Apr 21, 2018
bad4f2a
fix
sxjscience Apr 21, 2018
c1116f6
fix
sxjscience Apr 21, 2018
13ba8d7
fix
sxjscience Apr 21, 2018
7ed80fe
revert
sxjscience Apr 21, 2018
d520816
update
sxjscience Apr 21, 2018
6a26c37
fix
sxjscience Apr 21, 2018
9e339c8
fix
sxjscience Apr 21, 2018
beb5c3a
fix
sxjscience Apr 21, 2018
b5f979b
fix
sxjscience Apr 21, 2018
458f1c2
try to fix
sxjscience Apr 21, 2018
34a5319
update
sxjscience Apr 21, 2018
a4429fb
fix
sxjscience Apr 22, 2018
3208035
fix
sxjscience Apr 22, 2018
1097ca4
fix
sxjscience Apr 22, 2018
ae0c4b0
fix
sxjscience Apr 22, 2018
d98eb84
update
sxjscience Apr 22, 2018
77ecea5
fix
sxjscience Apr 22, 2018
abf78c9
fix
sxjscience Apr 22, 2018
fc30816
try to fix
sxjscience Apr 22, 2018
f3965bb
fix
sxjscience Apr 22, 2018
719fbc0
fix
sxjscience Apr 22, 2018
2100ebe
fix
sxjscience Apr 22, 2018
802a44f
fix
sxjscience Apr 22, 2018
1bd75d8
test
sxjscience Apr 22, 2018
1e2a47e
try to fix bleu
sxjscience Apr 22, 2018
bb4db30
fix
sxjscience Apr 22, 2018
6f0d995
fix
sxjscience Apr 22, 2018
0cd9f57
fi
sxjscience Apr 22, 2018
a7ed7b7
fix
sxjscience Apr 22, 2018
49913f9
fix
sxjscience Apr 22, 2018
0b28fa1
update
sxjscience Apr 22, 2018
b490ec0
fix
sxjscience Apr 22, 2018
a62b360
fix
sxjscience Apr 22, 2018
c35c8c4
fix
sxjscience Apr 22, 2018
a7a2b12
fix
sxjscience Apr 22, 2018
63fc4de
fix
sxjscience Apr 22, 2018
c54608f
fix
sxjscience Apr 22, 2018
a6584c0
fix
sxjscience Apr 22, 2018
59dbf64
fix
sxjscience Apr 22, 2018
c4775a2
fix
sxjscience Apr 22, 2018
73a6379
fix
sxjscience Apr 22, 2018
cc06927
fix
sxjscience Apr 22, 2018
455af38
add example
sxjscience Apr 22, 2018
3027687
fix
sxjscience Apr 22, 2018
d015968
fix
sxjscience Apr 22, 2018
25f8e00
fix
sxjscience Apr 22, 2018
270f8c1
fix
sxjscience Apr 22, 2018
a096dab
fix
sxjscience Apr 22, 2018
3d713c7
fix
sxjscience Apr 22, 2018
81ea8f9
fix example
sxjscience Apr 22, 2018
0cb0a48
update
sxjscience Apr 22, 2018
f309fe1
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 22, 2018
bb6bae6
add rst
sxjscience Apr 22, 2018
624362a
fix
sxjscience Apr 22, 2018
1466f54
remove nose
sxjscience Apr 22, 2018
5dd5282
add doc + fix lint
sxjscience Apr 22, 2018
928c8b5
try to fix lint
sxjscience Apr 22, 2018
c41da66
fix lint
sxjscience Apr 22, 2018
1f6fa8c
fix lint
sxjscience Apr 22, 2018
d2b86a2
fix
sxjscience Apr 22, 2018
30be9f1
fix lint
sxjscience Apr 22, 2018
a53fe25
fix
sxjscience Apr 22, 2018
8ee46c0
fix
sxjscience Apr 22, 2018
a518a3c
fix lint
sxjscience Apr 22, 2018
5f06916
fix lint
sxjscience Apr 22, 2018
ccdf925
fix
sxjscience Apr 22, 2018
a8098b1
add docstring
sxjscience Apr 22, 2018
3990b61
fix
sxjscience Apr 22, 2018
9931876
Merge remote-tracking branch 'upstream/master'
sxjscience Apr 23, 2018
9d4a3aa
fix
sxjscience Apr 23, 2018
0893872
fix bug
sxjscience Apr 23, 2018
1 change: 0 additions & 1 deletion docs/scripts

This file was deleted.

12 changes: 12 additions & 0 deletions docs/scripts/index.rst
@@ -0,0 +1,12 @@
Scripts
=======
Here are some useful training scripts.

.. include:: word_language_model.rst

See :download:`this example script <word_language_model.py>`

.. include:: sentiment_analysis.rst

See :download:`this example script <sentiment_analysis.py>`

341 changes: 341 additions & 0 deletions docs/scripts/sentiment_analysis.py
@@ -0,0 +1,341 @@
"""
Fine-tune Language Model for Sentiment Analysis
===============================================

This example shows how to load a language model pre-trained on wikitext-2 from the Gluon NLP Toolkit
model zoo, and reuse the language model encoder for sentiment analysis on the IMDB movie review dataset.
"""

# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import time
import random
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import Block, HybridBlock
from mxnet.gluon.data import SimpleDataset, ArrayDataset, DataLoader
import gluonnlp
from gluonnlp.data.sentiment import IMDB
from gluonnlp.data import batchify as bf
from gluonnlp.data.transforms import SpacyTokenizer, ClipSequence
from gluonnlp.data.sampler import FixedBucketSampler, SortedBucketSampler, SortedSampler
from gluonnlp.data.utils import train_valid_split
import multiprocessing as mp

np.random.seed(100)
random.seed(100)
mx.random.seed(10000)

tokenizer = SpacyTokenizer('en')
length_clip = ClipSequence(500)


def parse_args():
    parser = argparse.ArgumentParser(description='MXNet Sentiment Analysis Example on IMDB. '
                                                 'We load a LSTM model that is pretrained on WikiText '
                                                 'as our encoder.')
    parser.add_argument('--lm_model', type=str, default='standard_lstm_lm_200',
                        help='type of the pretrained model to load, can be "standard_lstm_lm_200", '
                             '"standard_lstm_lm_650", etc.')
    parser.add_argument('--use-mean-pool', type=bool, default=True,
                        help='whether to use mean pooling to aggregate the states from '
                             'different timestamps.')
    parser.add_argument('--no_pretrained', action='store_true',
                        help='Turn on the option to just use the structure and not load the '
                             'pretrained weights.')
    parser.add_argument('--lr', type=float, default=2.5E-3,
                        help='initial learning rate')
    parser.add_argument('--clip', type=float, default=None, help='gradient clipping')
    parser.add_argument('--bucket_type', type=str, default=None,
                        help='Can be "fixed" or "sorted"')
    parser.add_argument('--bucket_num', type=int, default=10,
                        help='The bucket_num if bucket_type is "fixed".')
    parser.add_argument('--bucket_ratio', type=float, default=0.0,
                        help='The ratio used in the FixedBucketSampler.')
    parser.add_argument('--bucket_mult', type=int, default=100,
                        help='The mult used in the SortedBucketSampler.')
    parser.add_argument('--valid_ratio', type=float, default=0.05,
                        help='Proportion [0, 1] of training samples to use for validation set.')
    parser.add_argument('--epochs', type=int, default=20,
                        help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                        help='batch size')
    parser.add_argument('--dropout', type=float, default=0.,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument('--log-interval', type=int, default=30, metavar='N',
                        help='report interval')
    parser.add_argument('--save', type=str, default='model.params',
                        help='path to save the final model')
    parser.add_argument('--gpu', type=int, default=None,
                        help='id of the GPU to use. Leave it unset to run on the CPU.')
    args = parser.parse_args()
    return args


def preprocess(x):
    data, label = x
    label = int(label > 5)
    data = vocab[length_clip(tokenizer(data))]
    return data, label


def get_length(x):
    return float(len(x[0]))


def load_data():
    # Load the dataset
    train_dataset, test_dataset = [IMDB(root='data/imdb', segment=segment)
                                   for segment in ('train', 'test')]
    train_dataset, valid_dataset = train_valid_split(train_dataset, args.valid_ratio)
    print("Tokenize using spaCy...")

    def preprocess_dataset(dataset):
        start = time.time()
        with mp.Pool(8) as pool:
            dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))
            lengths = gluon.data.SimpleDataset(pool.map(get_length, dataset))
        end = time.time()
        print('Done! Tokenizing Time={:.2f}s, #Sentences={}'.format(end - start, len(dataset)))
        return dataset, lengths

    # Preprocess the dataset
    train_dataset, train_data_lengths = preprocess_dataset(train_dataset)
    valid_dataset, valid_data_lengths = preprocess_dataset(valid_dataset)
    test_dataset, test_data_lengths = preprocess_dataset(test_dataset)
    return train_dataset, train_data_lengths, valid_dataset, valid_data_lengths, \
        test_dataset, test_data_lengths


class AggregationLayer(HybridBlock):
    """Aggregate the encoder states along the time axis, either by mean pooling over the
    valid time steps or by taking the last valid state."""
    def __init__(self, use_mean_pool=False, prefix=None, params=None):
        super(AggregationLayer, self).__init__(prefix=prefix, params=params)
        self._use_mean_pool = use_mean_pool

    def hybrid_forward(self, F, data, valid_length):
        # Data will have shape (T, N, C)
        if self._use_mean_pool:
            masked_encoded = F.SequenceMask(data,
                                            sequence_length=valid_length,
                                            use_sequence_length=True)
            agg_state = F.broadcast_div(F.sum(masked_encoded, axis=0),
                                        F.expand_dims(valid_length, axis=1))
        else:
            agg_state = F.SequenceLast(data,
                                       sequence_length=valid_length,
                                       use_sequence_length=True)
        return agg_state


class SentimentNet(Block):
    def __init__(self, lm_model, dropout, use_mean_pool=False, prefix=None, params=None):
        super(SentimentNet, self).__init__(prefix=prefix, params=params)
        self._use_mean_pool = use_mean_pool
        with self.name_scope():
            self.embedding = lm_model.embedding
            self.encoder = lm_model.encoder
            self.agg_layer = AggregationLayer(use_mean_pool=use_mean_pool)
            self.out_layer = gluon.nn.HybridSequential()
            with self.out_layer.name_scope():
                self.out_layer.add(gluon.nn.Dropout(dropout))
                self.out_layer.add(gluon.nn.Dense(1, flatten=False))

    def forward(self, data, valid_length):
        encoded = self.encoder(self.embedding(data))  # Shape(T, N, C)
        agg_state = self.agg_layer(encoded, valid_length)
        out = self.out_layer(agg_state)
        return out


def evaluate(net, dataloader, context):
    loss = gluon.loss.SigmoidBCELoss()
    total_L = 0.0
    total_sample_num = 0
    total_correct_num = 0
    start_log_interval_time = time.time()
    print('Begin Testing...')
    for i, ((data, valid_length), label) in enumerate(dataloader):
        data = mx.nd.transpose(data.as_in_context(context))
        valid_length = valid_length.as_in_context(context).astype(np.float32)
        label = label.as_in_context(context)
        output = net(data, valid_length)
        L = loss(output, label)
        pred = (output > 0.5).reshape((-1,))
        total_L += L.sum().asscalar()
        total_sample_num += label.shape[0]
        total_correct_num += (pred == label).sum().asscalar()
        if (i + 1) % args.log_interval == 0:
            print('[Batch {}/{}] elapsed {:.2f} s'.format(
                i + 1, len(dataloader), time.time() - start_log_interval_time))
            start_log_interval_time = time.time()
    avg_L = total_L / float(total_sample_num)
    acc = total_correct_num / float(total_sample_num)
    return avg_L, acc


args = parse_args()
print(args)
pretrained = not args.no_pretrained
# Load the pretrained model
if args.gpu is None:
    print("Use cpu")
    context = mx.cpu()
else:
    print("Use gpu%d" % args.gpu)
    context = mx.gpu(args.gpu)
lm_model, vocab = gluonnlp.model.get_model(name=args.lm_model,
                                           dataset_name='wikitext-2',
                                           pretrained=pretrained,
                                           ctx=context,
                                           dropout=args.dropout,
                                           prefix='sent_net_')
# Load and preprocess the dataset
train_dataset, train_data_lengths, \
    valid_dataset, valid_data_lengths, \
    test_dataset, test_data_lengths = load_data()


def train():
    start_pipeline_time = time.time()
    net = SentimentNet(lm_model=lm_model, dropout=args.dropout, use_mean_pool=args.use_mean_pool,
                       prefix='sent_net_')
    net.hybridize()
    print(net)
    if args.no_pretrained:
        net.collect_params().initialize(mx.init.Xavier(), ctx=context)
    else:
        net.out_layer.initialize(mx.init.Xavier(), ctx=context)
    trainer = gluon.Trainer(net.collect_params(), 'ftml', {'learning_rate': args.lr})
    loss = gluon.loss.SigmoidBCELoss()

    # Construct the DataLoader
    batchify_fn = bf.Tuple(bf.Pad(axis=0, ret_length=True), bf.Stack())  # Pad data and stack label
    if args.bucket_type is None:
        print("Bucketing strategy is not used!")
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      batchify_fn=batchify_fn)
    else:
        if args.bucket_type == "fixed":
            print("Use FixedBucketSampler")
            batch_sampler = FixedBucketSampler(train_data_lengths,
                                               batch_size=args.batch_size,
                                               num_buckets=args.bucket_num,
                                               ratio=args.bucket_ratio,
                                               shuffle=True)
            print(batch_sampler.stats())
        elif args.bucket_type == "sorted":
            print("Use SortedBucketSampler")
            batch_sampler = SortedBucketSampler(train_data_lengths,
                                                batch_size=args.batch_size,
                                                mult=args.bucket_mult,
                                                shuffle=True)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_sampler=batch_sampler,
                                      batchify_fn=batchify_fn)

    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  sampler=SortedSampler(valid_data_lengths),
                                  batchify_fn=batchify_fn)

    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 sampler=SortedSampler(test_data_lengths),
                                 batchify_fn=batchify_fn)

    # Training/Testing
    best_valid_acc = 0
    stop_early = 0
    for epoch in range(args.epochs):
        # Epoch training stats
        start_epoch_time = time.time()
        epoch_L = 0.0
        epoch_sent_num = 0
        epoch_wc = 0
        # Log interval training stats
        start_log_interval_time = time.time()
        log_interval_wc = 0
        log_interval_sent_num = 0
        log_interval_L = 0.0

        for i, ((data, valid_length), label) in enumerate(train_dataloader):
            data = mx.nd.transpose(data.as_in_context(context))
            label = label.as_in_context(context)
            valid_length = valid_length.as_in_context(context).astype(np.float32)
            wc = valid_length.sum().asscalar()
            log_interval_wc += wc
            epoch_wc += wc
            log_interval_sent_num += data.shape[1]
            epoch_sent_num += data.shape[1]
            with autograd.record():
                output = net(data, valid_length)
                L = loss(output, label).mean()
            L.backward()
            # Clip gradient
            if args.clip is not None:
                grads = [p.grad(context) for p in net.collect_params().values()]
                gluon.utils.clip_global_norm(grads, args.clip)
            # Update parameter
            trainer.step(1)
            log_interval_L += L.asscalar()
            epoch_L += L.asscalar()
            if (i + 1) % args.log_interval == 0:
                print('[Epoch %d Batch %d/%d] avg loss %g, throughput %gK wps' % (
                    epoch, i + 1, len(train_dataloader),
                    log_interval_L / log_interval_sent_num,
                    log_interval_wc / 1000 / (time.time() - start_log_interval_time)))
                # Clear log interval training stats
                start_log_interval_time = time.time()
                log_interval_wc = 0
                log_interval_sent_num = 0
                log_interval_L = 0
        end_epoch_time = time.time()
        valid_avg_L, valid_acc = evaluate(net, valid_dataloader, context)
        test_avg_L, test_acc = evaluate(net, test_dataloader, context)
        print('[Epoch %d] train avg loss %g, valid acc %.4f, valid avg loss %g, test acc %.4f, test avg loss %g, throughput %gK wps' % (
            epoch, epoch_L / epoch_sent_num,
            valid_acc, valid_avg_L, test_acc, test_avg_L,
            epoch_wc / 1000 / (end_epoch_time - start_epoch_time)))

        if valid_acc < best_valid_acc:
            print("No Improvement.")
            stop_early += 1
            if stop_early == 3:
                break
        else:
            # Reset stop_early if validation accuracy improves
            print("Observe Improvement")
            stop_early = 0
            net.save_params(args.save)
            best_valid_acc = valid_acc

    net.load_params(args.save, context)
    valid_avg_L, valid_acc = evaluate(net, valid_dataloader, context)
    test_avg_L, test_acc = evaluate(net, test_dataloader, context)
    print('Best validation loss %g, validation acc %.4f' % (valid_avg_L, valid_acc))
    print('Best test loss %g, test acc %.4f' % (test_avg_L, test_acc))
    print('Total time cost %.2fs' % (time.time() - start_pipeline_time))


if __name__ == "__main__":
    train()

21 changes: 21 additions & 0 deletions docs/scripts/sentiment_analysis.rst
@@ -0,0 +1,21 @@
Sentiment Analysis through Fine-tuning, w/ Bucketing
----------------------------------------------------

This script can be used to train a sentiment analysis model from scratch, or fine-tune a pre-trained language model.
The pre-trained language models are loaded from the Gluon NLP Toolkit model zoo. It also showcases how to use different
bucketing strategies to speed up training (see the sketch below).
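
A minimal illustrative sketch of the bucketing idea (not part of the training script; the toy dataset below is
hypothetical, and the imports follow the ones used in ``sentiment_analysis.py``): ``FixedBucketSampler`` groups
sequences of similar length into the same mini-batch, so each batch needs less padding.

.. code-block:: python

   from mxnet.gluon.data import DataLoader
   from gluonnlp.data import batchify as bf
   from gluonnlp.data.sampler import FixedBucketSampler

   # Hypothetical toy data: token-id lists of varying length with binary labels.
   dataset = [([1, 2, 3], 0), ([4, 5], 1), ([6, 7, 8, 9, 10], 1), ([11], 0)]
   lengths = [len(tokens) for tokens, _ in dataset]

   # Put samples of similar length into the same bucket/batch to reduce padding.
   sampler = FixedBucketSampler(lengths, batch_size=2, num_buckets=2, shuffle=True)
   print(sampler.stats())  # summary of the generated buckets and batch sizes

   # Pad the token ids (also returning valid lengths) and stack the labels.
   batchify_fn = bf.Tuple(bf.Pad(axis=0, ret_length=True), bf.Stack())
   loader = DataLoader(dataset, batch_sampler=sampler, batchify_fn=batchify_fn)
   for (data, valid_length), label in loader:
       print(data.shape, valid_length.asnumpy(), label.asnumpy())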

Use the following command to run without using a pretrained model:

.. code-block:: bash

   $ python sentiment_analysis.py --gpu 0 --batch_size 16 --bucket_type fixed --epochs 20 --dropout 0 --no_pretrained --lr 0.005 --valid_ratio 0.1 --save imdb_lstm_200.params # Test Accuracy 87.88

Use the following command to run with the pretrained model:

.. code-block:: bash

   $ python sentiment_analysis.py --gpu 0 --batch_size 16 --bucket_type fixed --epochs 20 --dropout 0 --lr 0.005 --valid_ratio 0.1 --save imdb_lstm_200.params # Test Accuracy 88.46


