diff --git a/docs/scripts b/docs/scripts deleted file mode 120000 index a339954dff..0000000000 --- a/docs/scripts +++ /dev/null @@ -1 +0,0 @@ -../scripts \ No newline at end of file diff --git a/docs/scripts/index.rst b/docs/scripts/index.rst new file mode 100644 index 0000000000..7f3e0dfa5a --- /dev/null +++ b/docs/scripts/index.rst @@ -0,0 +1,12 @@ +Scripts +======= +Here are some useful training scripts. + +.. include:: word_language_model.rst + +See :download:`this example script <word_language_model.py>` + +.. include:: sentiment_analysis.rst + +See :download:`this example script <sentiment_analysis.py>` + diff --git a/docs/scripts/sentiment_analysis.py b/docs/scripts/sentiment_analysis.py new file mode 100644 index 0000000000..281e4c94de --- /dev/null +++ b/docs/scripts/sentiment_analysis.py @@ -0,0 +1,341 @@ +""" +Fine-tune Language Model for Sentiment Analysis +=============================================== + +This example shows how to load a language model pre-trained on wikitext-2 in the Gluon NLP Toolkit model +zoo, and reuse the language model encoder for sentiment analysis on the IMDB movie reviews dataset. +""" + +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import time +import random +import numpy as np +import mxnet as mx +from mxnet import gluon, autograd +from mxnet.gluon import Block, HybridBlock +from mxnet.gluon.data import SimpleDataset, ArrayDataset, DataLoader +import gluonnlp +from gluonnlp.data.sentiment import IMDB +from gluonnlp.data import batchify as bf +from gluonnlp.data.transforms import SpacyTokenizer, ClipSequence +from gluonnlp.data.sampler import FixedBucketSampler, SortedBucketSampler, SortedSampler +from gluonnlp.data.utils import train_valid_split +import multiprocessing as mp + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +tokenizer = SpacyTokenizer('en') +length_clip = ClipSequence(500) + + +def parse_args(): + parser = argparse.ArgumentParser(description='MXNet Sentiment Analysis Example on IMDB. '
+ 'We load a LSTM model that is pretrained on WikiText ' + 'as our encoder.') + parser.add_argument('--lm_model', type=str, default='standard_lstm_lm_200', + help='type of the pretrained model to load, can be "standard_lstm_lm_200", ' + '"standard_lstm_lm_650", etc.') + parser.add_argument('--use-mean-pool', type=bool, default=True, help="whether to use mean pooling to aggregate the states from different time steps.") + parser.add_argument('--no_pretrained', action='store_true', help='Use the model structure only, without loading the pretrained weights.') + parser.add_argument('--lr', type=float, default=2.5E-3, + help='initial learning rate') + parser.add_argument('--clip', type=float, default=None, help='gradient clipping') + parser.add_argument('--bucket_type', type=str, default=None, + help='Can be "fixed" or "sorted"') + parser.add_argument('--bucket_num', type=int, default=10, help='The number of buckets if bucket_type is ' + '"fixed".') + parser.add_argument('--bucket_ratio', type=float, default=0.0, + help='The ratio used in the FixedBucketSampler.') + parser.add_argument('--bucket_mult', type=int, default=100, + help='The multiplier used in the SortedBucketSampler.') + parser.add_argument('--valid_ratio', type=float, default=0.05, + help='Proportion [0, 1] of training samples to use for the validation set.') + parser.add_argument('--epochs', type=int, default=20, + help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=16, metavar='N', + help='batch size') + parser.add_argument('--dropout', type=float, default=0., + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--log-interval', type=int, default=30, metavar='N', + help='report interval') + parser.add_argument('--save', type=str, default='model.params', + help='path to save the final model') + parser.add_argument('--gpu', type=int, default=None, + help='id of the gpu to use. Leave it unset to use the cpu.') + args = parser.parse_args() + return args + + +def preprocess(x): + data, label = x + label = int(label > 5) + data = vocab[length_clip(tokenizer(data))] + return data, label + + +def get_length(x): + return float(len(x[0])) + + +def load_data(): + # Load the dataset + train_dataset, test_dataset = [IMDB(root='data/imdb', segment=segment) for segment in ('train', 'test')] + train_dataset, valid_dataset = train_valid_split(train_dataset, args.valid_ratio) + print("Tokenize using spaCy...") + + def preprocess_dataset(dataset): + start = time.time() + with mp.Pool(8) as pool: + dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset)) + lengths = gluon.data.SimpleDataset(pool.map(get_length, dataset)) + end = time.time() + print('Done! 
Tokenizing Time={:.2f}s, #Sentences={}'.format(end - start, len(dataset))) + return dataset, lengths + + # Preprocess the dataset + train_dataset, train_data_lengths = preprocess_dataset(train_dataset) + valid_dataset, valid_data_lengths = preprocess_dataset(valid_dataset) + test_dataset, test_data_lengths = preprocess_dataset(test_dataset) + return train_dataset, train_data_lengths, valid_dataset, valid_data_lengths, test_dataset, test_data_lengths + + +class AggregationLayer(HybridBlock): + def __init__(self, use_mean_pool=False, prefix=None, params=None): + super(AggregationLayer, self).__init__(prefix=prefix, params=params) + self._use_mean_pool = use_mean_pool + + def hybrid_forward(self, F, data, valid_length): + # Data will have shape (T, N, C) + if self._use_mean_pool: + masked_encoded = F.SequenceMask(data, + sequence_length=valid_length, + use_sequence_length=True) + agg_state = F.broadcast_div(F.sum(masked_encoded, axis=0), + F.expand_dims(valid_length, axis=1)) + else: + agg_state = F.SequenceLast(data, + sequence_length=valid_length, + use_sequence_length=True) + return agg_state + + +class SentimentNet(Block): + def __init__(self, lm_model, dropout, use_mean_pool=False, prefix=None, params=None): + super(SentimentNet, self).__init__(prefix=prefix, params=params) + self._use_mean_pool = use_mean_pool + with self.name_scope(): + self.embedding = lm_model.embedding + self.encoder = lm_model.encoder + self.agg_layer = AggregationLayer(use_mean_pool=use_mean_pool) + self.out_layer = gluon.nn.HybridSequential() + with self.out_layer.name_scope(): + self.out_layer.add(gluon.nn.Dropout(dropout)) + self.out_layer.add(gluon.nn.Dense(1, flatten=False)) + + def forward(self, data, valid_length): + encoded = self.encoder(self.embedding(data)) # Shape(T, N, C) + agg_state = self.agg_layer(encoded, valid_length) + out = self.out_layer(agg_state) + return out + + +def evaluate(net, dataloader, context): + loss = gluon.loss.SigmoidBCELoss() + total_L = 0.0 + total_sample_num = 0 + total_correct_num = 0 + start_log_interval_time = time.time() + print('Begin Testing...') + for i, ((data, valid_length), label) in enumerate(dataloader): + data = mx.nd.transpose(data.as_in_context(context)) + valid_length = valid_length.as_in_context(context).astype(np.float32) + label = label.as_in_context(context) + output = net(data, valid_length) + L = loss(output, label) + pred = (output > 0.5).reshape((-1,)) + total_L += L.sum().asscalar() + total_sample_num += label.shape[0] + total_correct_num += (pred == label).sum().asscalar() + if (i + 1) % args.log_interval == 0: + print('[Batch {}/{}] elapsed {:.2f} s'.format( + i + 1, len(dataloader), time.time() - start_log_interval_time)) + start_log_interval_time = time.time() + avg_L = total_L / float(total_sample_num) + acc = total_correct_num / float(total_sample_num) + return avg_L, acc + + +args = parse_args() +print(args) +pretrained = not args.no_pretrained +# Load the pretrained model +if args.gpu is None: + print("Use cpu") + context = mx.cpu() +else: + print("Use gpu%d" % args.gpu) + context = mx.gpu(args.gpu) +lm_model, vocab = gluonnlp.model.get_model(name=args.lm_model, + dataset_name='wikitext-2', + pretrained=pretrained, + ctx=context, + dropout=args.dropout, + prefix='sent_net_') +# Load and preprocess the dataset +train_dataset, train_data_lengths, \ +valid_dataset, valid_data_lengths, \ +test_dataset, test_data_lengths = load_data() + + +def train(): + start_pipeline_time = time.time() + net = SentimentNet(lm_model=lm_model, dropout=args.dropout, 
use_mean_pool=args.use_mean_pool, + prefix='sent_net_') + net.hybridize() + print(net) + if args.no_pretrained: + net.collect_params().initialize(mx.init.Xavier(), ctx=context) + else: + net.out_layer.initialize(mx.init.Xavier(), ctx=context) + trainer = gluon.Trainer(net.collect_params(), 'ftml', {'learning_rate': args.lr}) + loss = gluon.loss.SigmoidBCELoss() + + # Construct the DataLoader + batchify_fn = bf.Tuple(bf.Pad(axis=0, ret_length=True), bf.Stack()) # Pad data and stack label + if args.bucket_type is None: + print("Bucketing strategy is not used!") + train_dataloader = DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + batchify_fn=batchify_fn) + else: + if args.bucket_type == "fixed": + print("Use FixedBucketSampler") + batch_sampler = FixedBucketSampler(train_data_lengths, + batch_size=args.batch_size, + num_buckets=args.bucket_num, + ratio=args.bucket_ratio, + shuffle=True) + print(batch_sampler.stats()) + elif args.bucket_type == "sorted": + print("Use SortedBucketSampler") + batch_sampler = SortedBucketSampler(train_data_lengths, + batch_size=args.batch_size, + mult=args.bucket_mult, + shuffle=True) + else: + raise NotImplementedError + train_dataloader = DataLoader(dataset=train_dataset, + batch_sampler=batch_sampler, + batchify_fn=batchify_fn) + + valid_dataloader = DataLoader(dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=False, + sampler=SortedSampler(valid_data_lengths), + batchify_fn=batchify_fn) + + test_dataloader = DataLoader(dataset=test_dataset, + batch_size=args.batch_size, + shuffle=False, + sampler=SortedSampler(test_data_lengths), + batchify_fn=batchify_fn) + + # Training/Testing + best_valid_acc = 0 + stop_early = 0 + for epoch in range(args.epochs): + # Epoch training stats + start_epoch_time = time.time() + epoch_L = 0.0 + epoch_sent_num = 0 + epoch_wc = 0 + # Log interval training stats + start_log_interval_time = time.time() + log_interval_wc = 0 + log_interval_sent_num = 0 + log_interval_L = 0.0 + + for i, ((data, valid_length), label) in enumerate(train_dataloader): + data = mx.nd.transpose(data.as_in_context(context)) + label = label.as_in_context(context) + valid_length = valid_length.as_in_context(context).astype(np.float32) + wc = valid_length.sum().asscalar() + log_interval_wc += wc + epoch_wc += wc + log_interval_sent_num += data.shape[1] + epoch_sent_num += data.shape[1] + with autograd.record(): + output = net(data, valid_length) + L = loss(output, label).mean() + L.backward() + # Clip gradient + if args.clip is not None: + grads = [p.grad(context) for p in net.collect_params().values()] + gluon.utils.clip_global_norm(grads, args.clip) + # Update parameter + trainer.step(1) + log_interval_L += L.asscalar() + epoch_L += L.asscalar() + if (i + 1) % args.log_interval == 0: + print('[Epoch %d Batch %d/%d] avg loss %g, throughput %gK wps' % ( + epoch, i + 1, len(train_dataloader), + log_interval_L / log_interval_sent_num, + log_interval_wc / 1000 / (time.time() - start_log_interval_time))) + # Clear log interval training stats + start_log_interval_time = time.time() + log_interval_wc = 0 + log_interval_sent_num = 0 + log_interval_L = 0 + end_epoch_time = time.time() + valid_avg_L, valid_acc = evaluate(net, valid_dataloader, context) + test_avg_L, test_acc = evaluate(net, test_dataloader, context) + print('[Epoch %d] train avg loss %g, valid acc %.4f, valid avg loss %g, test acc %.4f, test avg loss %g, throughput %gK wps' % ( + epoch, epoch_L / epoch_sent_num, + valid_acc, valid_avg_L, test_acc, 
test_avg_L, + epoch_wc / 1000 / (end_epoch_time - start_epoch_time))) + + if valid_acc < best_valid_acc: + print("No Improvement.") + stop_early += 1 + if stop_early == 3: + break + else: + # Reset stop_early if the validation accuracy improves + print("Observe Improvement") + stop_early = 0 + net.save_params(args.save) + best_valid_acc = valid_acc + + net.load_params(args.save, context) + valid_avg_L, valid_acc = evaluate(net, valid_dataloader, context) + test_avg_L, test_acc = evaluate(net, test_dataloader, context) + print('Best validation loss %g, validation acc %.4f'%(valid_avg_L, valid_acc)) + print('Best test loss %g, test acc %.4f'%(test_avg_L, test_acc)) + print('Total time cost %.2fs'%(time.time()-start_pipeline_time)) + + +if __name__ == "__main__": + train() + diff --git a/docs/scripts/sentiment_analysis.rst b/docs/scripts/sentiment_analysis.rst new file mode 100644 index 0000000000..692282f6d3 --- /dev/null +++ b/docs/scripts/sentiment_analysis.rst @@ -0,0 +1,21 @@ +Sentiment Analysis through Fine-tuning, w/ Bucketing +---------------------------------------------------- + +This script can be used to train a sentiment analysis model from scratch, or fine-tune a pre-trained language model. +The pre-trained language models are loaded from the Gluon NLP Toolkit model zoo. It also showcases how to use different +bucketing strategies to speed up training. + +Use the following command to run without using a pretrained model: + +.. code-block:: bash + + $ python sentiment_analysis.py --gpu 0 --batch_size 16 --bucket_type fixed --epochs 20 --dropout 0 --no_pretrained --lr 0.005 --valid_ratio 0.1 --save imdb_lstm_200.params # Test Accuracy 87.88 + +Use the following command to run with a pretrained model: + +.. code-block:: bash + + $ python sentiment_analysis.py --gpu 0 --batch_size 16 --bucket_type fixed --epochs 20 --dropout 0 --lr 0.005 --valid_ratio 0.1 --save imdb_lstm_200.params # Test Accuracy 88.46 + + + diff --git a/docs/scripts/word_language_model.py b/docs/scripts/word_language_model.py new file mode 100644 index 0000000000..b5c7d460b1 --- /dev/null +++ b/docs/scripts/word_language_model.py @@ -0,0 +1,229 @@ +""" +Word Language Model +=================== + +This example shows how to build a word-level language model on WikiText-2 with the Gluon NLP Toolkit. +By using the existing data pipeline tools and building blocks, the process is greatly simplified. +""" + +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import argparse +import time +import math +import mxnet as mx +from mxnet import gluon, autograd +import os +import sys +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '..', '..')) +from gluonnlp import datasets, Vocab +from gluonnlp.models.language_model import StandardRNN, AWDRNN +from gluonnlp.data import Counter +from mxnet.gluon.data import SimpleDataset + +parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.') +parser.add_argument('--model', type=str, default='lstm', + help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)') +parser.add_argument('--emsize', type=int, default=200, + help='size of word embeddings') +parser.add_argument('--nhid', type=int, default=200, + help='number of hidden units per layer') +parser.add_argument('--nlayers', type=int, default=2, + help='number of layers') +parser.add_argument('--lr', type=float, default=20, + help='initial learning rate') +parser.add_argument('--clip', type=float, default=0.25, + help='gradient clipping') +parser.add_argument('--epochs', type=int, default=40, + help='upper epoch limit') +parser.add_argument('--batch_size', type=int, default=20, metavar='N', + help='batch size') +parser.add_argument('--bptt', type=int, default=35, + help='sequence length') +parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--dropout_h', type=float, default=0.3, + help='dropout applied to hidden layer (0 = no dropout)') +parser.add_argument('--dropout_i', type=float, default=0.65, + help='dropout applied to input layer (0 = no dropout)') +parser.add_argument('--weight_dropout', type=float, default=0.0, + help='weight dropout applied to h2h weight matrix (0 = no weight dropout)') +parser.add_argument('--tied', action='store_true', + help='tie the word embedding and softmax weights') +parser.add_argument('--log-interval', type=int, default=200, metavar='N', + help='report interval') +parser.add_argument('--save', type=str, default='model.params', + help='path to save the final model') +parser.add_argument('--gctype', type=str, default='none', + help='type of gradient compression to use, \ + takes `2bit` or `none` for now.') +parser.add_argument('--gcthreshold', type=float, default=0.5, + help='threshold for 2bit gradient compression') +parser.add_argument('--eval_only', action='store_true', + help='Whether to only evaluate the trained model') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. 
(the result of multi-gpu training might be slightly different compared to single-gpu training; this still needs to be finalized)') +args = parser.parse_args() + +print(args) + + +############################################################################### +# Load data +############################################################################### + +context = [mx.cpu()] if args.gpus is None or args.gpus == "" else \ + [mx.gpu(int(i)) for i in args.gpus.split(',')] + +assert args.batch_size % len(context) == 0, "Total batch size must be a multiple of the number of devices" + +train_dataset, val_dataset, test_dataset = [datasets.WikiText2(segment=segment, + bos=None, eos='<eos>') + for segment in ['train', 'val', 'test']] + +vocab = Vocab(Counter(train_dataset[0]), padding_token=None, bos_token=None) + +train_data, val_data, test_data = [x.bptt_batchify(vocab, args.bptt, args.batch_size, + last_batch='keep') + for x in [train_dataset, val_dataset, test_dataset]] + + +############################################################################### +# Build the model +############################################################################### + + +ntokens = len(vocab) + +if args.weight_dropout > 0: + print("Use weight_drop!") + model = AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, + args.tied, args.dropout, args.weight_dropout, args.dropout_h, args.dropout_i) +else: + model = StandardRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, args.dropout, + args.tied) + +model.initialize(mx.init.Xavier(), ctx=context) + + +compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold} +trainer = gluon.Trainer(model.collect_params(), 'sgd', + {'learning_rate': args.lr, + 'momentum': 0, + 'wd': 0}, + compression_params=compression_params) +loss = gluon.loss.SoftmaxCrossEntropyLoss() + +############################################################################### +# Training code +############################################################################### + +def detach(hidden): + if isinstance(hidden, (tuple, list)): + hidden = [detach(i) for i in hidden] + else: + hidden = hidden.detach() + return hidden + +def evaluate(data_source, ctx): + total_L = 0.0 + ntotal = 0 + hidden = model.begin_state(args.batch_size, func=mx.nd.zeros, ctx=ctx) + for i, (data, target) in enumerate(data_source): + data = data.as_in_context(ctx) + target = target.as_in_context(ctx) + output, hidden = model(data, hidden) + L = loss(output.reshape(-3, -1), + target.reshape(-1)) + total_L += mx.nd.sum(L).asscalar() + ntotal += L.size + return total_L / ntotal + +def train(): + best_val = float("Inf") + start_train_time = time.time() + parameters = model.collect_params().values() + for epoch in range(args.epochs): + total_L, n_total = 0.0, 0 + start_epoch_time = time.time() + start_log_interval_time = time.time() + hiddens = [model.begin_state(args.batch_size//len(context), func=mx.nd.zeros, ctx=ctx) for ctx in context] + for i, (data, target) in enumerate(train_data): + data_list = gluon.utils.split_and_load(data, context, batch_axis=1, even_split=True) + target_list = gluon.utils.split_and_load(target, context, batch_axis=1, even_split=True) + hiddens = detach(hiddens) + L = 0 + Ls = [] + with autograd.record(): + for j, (X, y, h) in enumerate(zip(data_list, target_list, hiddens)): + output, h = model(X, h) + batch_L = loss(output.reshape(-3, -1), y.reshape(-1)) + L = L + batch_L.as_in_context(context[0]) / X.size + Ls.append(batch_L) + hiddens[j] = h
+ L.backward() + grads = [p.grad(x.context) for p in parameters for x in data_list] + gluon.utils.clip_global_norm(grads, args.clip) + + trainer.step(1) + + total_L += sum([mx.nd.sum(l).asscalar() for l in Ls]) + n_total += data.size + + if i % args.log_interval == 0 and i > 0: + cur_L = total_L / n_total + print('[Epoch %d Batch %d/%d] loss %.2f, ppl %.2f, throughput %.2f samples/s'%( + epoch, i, len(train_data), cur_L, math.exp(cur_L), + args.batch_size * args.log_interval / (time.time() - start_log_interval_time))) + total_L, n_total = 0.0, 0 + start_log_interval_time = time.time() + + mx.nd.waitall() + + print('[Epoch %d] throughput %.2f samples/s'%( + epoch, (args.batch_size * len(train_data)) / (time.time() - start_epoch_time))) + val_L = evaluate(val_data, context[0]) + print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%( + epoch, time.time()-start_epoch_time, val_L, math.exp(val_L))) + + if val_L < best_val: + best_val = val_L + test_L = evaluate(test_data, context[0]) + model.save_params(args.save) + print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L))) + else: + args.lr = args.lr*0.25 + print('Learning rate now %f'%(args.lr)) + trainer.set_learning_rate(args.lr) + + print('Total training throughput %.2f samples/s'%( + (args.batch_size * len(train_data) * args.epochs) / (time.time() - start_train_time))) + +if __name__ == '__main__': + start_pipeline_time = time.time() + if not args.eval_only: + train() + model.load_params(args.save, context) + val_L = evaluate(val_data, context[0]) + test_L = evaluate(test_data, context[0]) + print('Best validation loss %.2f, valid ppl %.2f'%(val_L, math.exp(val_L))) + print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L))) + print('Total time cost %.2fs'%(time.time()-start_pipeline_time)) diff --git a/docs/scripts/word_language_model.rst b/docs/scripts/word_language_model.rst new file mode 100644 index 0000000000..5eb371fc55 --- /dev/null +++ b/docs/scripts/word_language_model.rst @@ -0,0 +1,17 @@ +Word Language Model +------------------- + +This script can be used to train language models with the given specification. + +Use the following command to run the small setting (embed and hidden size = 200): + +.. code-block:: bash + + $ python word_language_model.py --tied --gpus 0 --save wiki2_lstm_200.params # Test PPL 102.91 + +Use the following command to run the large setting (embed and hidden size = 650): + +.. code-block:: bash + + $ python word_language_model.py --emsize 650 --nhid 650 --dropout 0.5 --tied --gpus 0 --save wiki2_lstm_650.params # Test PPL 89.01 + diff --git a/docs/scripts~upstream_master b/docs/scripts~upstream_master new file mode 100644 index 0000000000..a339954dff --- /dev/null +++ b/docs/scripts~upstream_master @@ -0,0 +1 @@ +../scripts \ No newline at end of file diff --git a/scripts/beam_search/beam_search_geneartor.rst b/scripts/beam_search/beam_search_geneartor.rst new file mode 100644 index 0000000000..c0623e325e --- /dev/null +++ b/scripts/beam_search/beam_search_geneartor.rst @@ -0,0 +1,30 @@ +Beam Search Generator +--------------------- + +This script can be used to generate sentences using beam search from a pretrained language model. + +Use the following command to generate sentences: + +.. code-block:: bash + + $ python beam_search_generator.py + +The output is: + +.. 
code-block:: log + + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , which was .', 260.96414] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , where he .', 260.7027] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima in May .', 259.5865] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , which would .', 259.58163] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , but was .', 259.5562] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , and then .', 259.5449] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , and to .', 259.51816] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima during the Second .', 259.4851] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , when he .', 259.41025] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima on the night .', 259.36902] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , which had .', 259.35846] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , and the .', 259.2447] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , which he .', 259.23218] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , as well .', 259.19528] + ['he was able to take part in an attempt to take part in the Siege of Iwo Jima , in which .', 259.1105] diff --git a/scripts/beam_search/beam_search_generator.py b/scripts/beam_search/beam_search_generator.py new file mode 100644 index 0000000000..a38ddc8dc4 --- /dev/null +++ b/scripts/beam_search/beam_search_generator.py @@ -0,0 +1,105 @@ +""" +Generate Sentences by Beam Search +================================== + +This example shows how to load a pretrained language model on wikitext-2 in Gluon NLP Toolkit model +zoo, and use beam search to generate sentences. +""" + +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import mxnet as mx +import gluonnlp as nlp + + +parser = argparse.ArgumentParser(description='Generate sentences by beam search. 
' + 'We load a LSTM model that is pretrained on ' + 'WikiText as our encoder.') +parser.add_argument('--lm_model', type=str, default='awd_lstm_lm_1150', + help='type of the pretrained model to load, can be "standard_lstm_lm_200", ' + '"standard_lstm_lm_650", "standard_lstm_lm_1500", ' + '"awd_lstm_lm_1150", etc.') +parser.add_argument('--beam_size', type=int, default=15, + help='Beam size in the beam search sampler.') +parser.add_argument('--alpha', type=float, default=0.0, help='Alpha in the length penalty term.') +parser.add_argument('--k', type=int, default=5, help='K in the length penalty term.') +parser.add_argument('--bos', type=str, default='he') +parser.add_argument('--eos', type=str, default='.') +parser.add_argument('--max_length', type=int, default=20, help='Maximum sentence length.') +parser.add_argument('--gpu', type=int, default=None, + help='id of the gpu to use. Leave it unset to use the cpu.') +args = parser.parse_args() +print(args) +if args.gpu is None: + ctx = mx.cpu() +else: + ctx = mx.gpu(args.gpu) + +lm_model, vocab = nlp.model.get_model(name=args.lm_model, + dataset_name='wikitext-2', + pretrained=True, + ctx=ctx) + + +def _transform_layout(data): + if isinstance(data, list): + return [_transform_layout(ele) for ele in data] + elif isinstance(data, mx.nd.NDArray): + return mx.nd.transpose(data, axes=(1, 0, 2)) + else: + raise NotImplementedError + +# Define the decoder function. We use log_softmax to map the output scores to log-likelihoods. +# Also, we transform the layout to NTC. +def decoder(inputs, states): + states = _transform_layout(states) + outputs, states = lm_model(mx.nd.expand_dims(inputs, axis=0), states) + states = _transform_layout(states) + return outputs[0], states + + +def generate(): + bos_id = vocab[args.bos] + eos_id = vocab[args.eos] + begin_states = lm_model.begin_state(batch_size=1, ctx=ctx) + inputs = mx.nd.full(shape=(1,), ctx=ctx, val=bos_id) + scorer = nlp.model.BeamSearchScorer(alpha=args.alpha, K=args.k) + sampler = nlp.model.BeamSearchSampler(beam_size=args.beam_size, + decoder=decoder, + eos_id=eos_id, + scorer=scorer, + max_length=args.max_length) + samples, scores, valid_lengths = sampler(inputs, begin_states) + samples = samples[0].asnumpy() + scores = scores[0].asnumpy() + valid_lengths = valid_lengths[0].asnumpy() + print("Beam Search Parameters: beam_size={}, alpha={}, K={}".format(args.beam_size, + args.alpha, + args.k)) + print("Generation Result:") + for i in range(args.beam_size): + sentence = [vocab.idx_to_token[ele] for ele in samples[i][:valid_lengths[i]]] + print([' '.join(sentence), scores[i]]) + + +if __name__ == '__main__': + generate() + diff --git a/scripts/nmt/translation.py b/scripts/nmt/translation.py index cfd8b43dfd..f687b46e38 100644 --- a/scripts/nmt/translation.py +++ b/scripts/nmt/translation.py @@ -26,7 +26,7 @@ from mxnet.gluon import Block from mxnet.gluon import nn import mxnet as mx -from .beam_search import BeamSearchScorer, BeamSearchSampler +from gluonnlp.model.beam_search import BeamSearchScorer, BeamSearchSampler class NMTModel(Block):