Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1335][fit api]Text Sentiment Classification examples using Gluon fit() API #14350

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions example/gluon/estimator_example/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Gluon Estimator data utilities
Example modified from below link:
https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md"""

import collections
from mxnet import nd
from mxnet.contrib import text
from mxnet.gluon import utils as gutils
import os
import random
import tarfile


def download_imdb(data_dir='./data'):
url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')
sha1 = '01ada507287d82875905620988597833ad4e0903'
fname = gutils.download(url, data_dir, sha1_hash=sha1)
with tarfile.open(fname, 'r') as f:
f.extractall(data_dir)


def read_imdb(folder='train'):
data = []
for label in ['pos', 'neg']:
folder_name = os.path.join('./data/aclImdb/', folder, label)
for file in os.listdir(folder_name):
with open(os.path.join(folder_name, file), 'rb') as f:
review = f.read().decode('utf-8').replace('\n', '').lower()
data.append([review, 1 if label == 'pos' else 0])
random.shuffle(data)
return data


def get_tokenized_imdb(data):
def tokenizer(text):
return [tok.lower() for tok in text.split(' ')]

return [tokenizer(review) for review, _ in data]


def get_vocab_imdb(data):
tokenized_data = get_tokenized_imdb(data)
counter = collections.Counter([tk for st in tokenized_data for tk in st])
return text.vocab.Vocabulary(counter, min_freq=5)


def preprocess_imdb(data, vocab):
# Make the length of each comment 500 by truncating or adding 0s
max_l = 500

def pad(x):
return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))

tokenized_data = get_tokenized_imdb(data)
features = nd.array([pad(vocab.to_indices(x)) for x in tokenized_data])
labels = nd.array([score for _, score in data])
return features, labels
125 changes: 125 additions & 0 deletions example/gluon/estimator_example/text_sentiment_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Gluon Text Sentiment Classification Example using CNN
Example modified from below link to demostrate using fit() API:
https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-cnn.md"""

import argparse
from mxnet import gluon, init, nd, cpu, gpu
from mxnet.contrib import text
from mxnet.gluon import data as gdata, loss as gloss, nn
from mxnet.gluon.estimator import estimator as est
from data_utils import download_imdb, read_imdb, get_vocab_imdb, preprocess_imdb


def corr1d(X, K):
w = K.shape[0]
Y = nd.zeros((X.shape[0] - w + 1))
for i in range(Y.shape[0]):
Y[i] = (X[i: i + w] * K).sum()
return Y


def corr1d_multi_in(X, K):
# First, we traverse along the 0th dimension (channel dimension) of X and
# K. Then, we add them together by using * to turn the result list into a
# positional argument of the add_n function
return nd.add_n(*[corr1d(x, k) for x, k in zip(X, K)])


# Model
class TextCNN(nn.Block):
def __init__(self, vocab, embed_size, kernel_sizes, num_channels,
**kwargs):
super(TextCNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(len(vocab), embed_size)
# The embedding layer does not participate in training
self.constant_embedding = nn.Embedding(len(vocab), embed_size)
self.dropout = nn.Dropout(0.5)
self.decoder = nn.Dense(2)
# The max-over-time pooling layer has no weight, so it can share an
# instance
self.pool = nn.GlobalMaxPool1D()
# Create multiple one-dimensional convolutional layers
self.convs = nn.Sequential()
for c, k in zip(num_channels, kernel_sizes):
self.convs.add(nn.Conv1D(c, k, activation='relu'))

def forward(self, inputs):
# Concatenate the output of two embedding layers with shape of
# (batch size, number of words, word vector dimension) by word vector
embeddings = nd.concat(
self.embedding(inputs), self.constant_embedding(inputs), dim=2)
# According to the input format required by Conv1D, the word vector
# dimension, that is, the channel dimension of the one-dimensional
# convolutional layer, is transformed into the previous dimension
embeddings = embeddings.transpose((0, 2, 1))
# For each one-dimensional convolutional layer, after max-over-time
# pooling, an NDArray with the shape of (batch size, channel size, 1)
# can be obtained. Use the flatten function to remove the last
# dimension and then concatenate on the channel dimension
encoding = nd.concat(*[nd.flatten(
self.pool(conv(embeddings))) for conv in self.convs], dim=1)
# After applying the dropout method, use a fully connected layer to
# obtain the output
outputs = self.decoder(self.dropout(encoding))
return outputs


if __name__ == '__main__':
# Parse CLI arguments
parser = argparse.ArgumentParser(description='MXNet Gluon Text Sentiment '
'Classification Example using CNN')
parser.add_argument('--batch-size', type=int, default=64,
help='batch size for training and testing (default: 64)')
parser.add_argument('--epochs', type=int, default=5,
help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=0.001,
help='learning rate (default: 0.01)')
parser.add_argument('--use-gpu', action='store_true', default=False,
help='whether to use GPU (default: False)')
opt = parser.parse_args()

ctx = gpu(0) if opt.use_gpu else cpu()

# data
download_imdb()
train_data, test_data = read_imdb('train'), read_imdb('test')
vocab = get_vocab_imdb(train_data)
train_iter = gdata.DataLoader(gdata.ArrayDataset(
*preprocess_imdb(train_data, vocab)), opt.batch_size, shuffle=True)
test_iter = gdata.DataLoader(gdata.ArrayDataset(
*preprocess_imdb(test_data, vocab)), opt.batch_size)

# Initialize
embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
net = TextCNN(vocab, embed_size, kernel_sizes, nums_channels)
net.initialize(init.Xavier(), ctx=ctx)

glove_embedding = text.embedding.create(
'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab)
net.embedding.weight.set_data(glove_embedding.idx_to_vec)
net.constant_embedding.weight.set_data(glove_embedding.idx_to_vec)
net.constant_embedding.collect_params().setattr('grad_req', 'null')

trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': opt.lr})
loss = gloss.SoftmaxCrossEntropyLoss()

# train
e = est.Estimator(net=net, loss=loss, trainers=trainer, context=ctx)
e.fit(train_iter, test_iter, epochs=opt.epochs, batch_size=opt.batch_size)
100 changes: 100 additions & 0 deletions example/gluon/estimator_example/text_sentiment_rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Gluon Text Sentiment Classification Example using RNN
Example modified from below link to demostrate using fit() API:
https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md"""

import argparse
from mxnet import gluon, init, nd, cpu, gpu
from mxnet.contrib import text
from mxnet.gluon import data as gdata, loss as gloss, nn, rnn
from mxnet.gluon.estimator import estimator as est
from data_utils import download_imdb, read_imdb, get_vocab_imdb, preprocess_imdb


# Model
class BiRNN(nn.Block):
def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs):
super(BiRNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(len(vocab), embed_size)
# Set Bidirectional to True to get a bidirectional recurrent neural
# network
self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers,
bidirectional=True, input_size=embed_size)
self.decoder = nn.Dense(2)

def forward(self, inputs):
# The shape of inputs is (batch size, number of words). Because LSTM
# needs to use sequence as the first dimension, the input is
# transformed and the word feature is then extracted. The output shape
# is (number of words, batch size, word vector dimension).
embeddings = self.embedding(inputs.T)
# The shape of states is (number of words, batch size, 2 * number of
# hidden units).
states = self.encoder(embeddings)
# Concatenate the hidden states of the initial time step and final
# time step to use as the input of the fully connected layer. Its
# shape is (batch size, 4 * number of hidden units)
encoding = nd.concat(states[0], states[-1])
outputs = self.decoder(encoding)
return outputs


if __name__ == '__main__':
# Parse CLI arguments
parser = argparse.ArgumentParser(description='MXNet Gluon Text Sentiment '
'Classification Example using RNN')
parser.add_argument('--batch-size', type=int, default=64,
help='batch size for training and testing (default: 64)')
parser.add_argument('--epochs', type=int, default=5,
help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
parser.add_argument('--use-gpu', action='store_true', default=False,
help='whether to use GPU (default: False)')
opt = parser.parse_args()

ctx = gpu(0) if opt.use_gpu else cpu()

# data
download_imdb()
train_data, test_data = read_imdb('train'), read_imdb('test')
vocab = get_vocab_imdb(train_data)

train_set = gdata.ArrayDataset(*preprocess_imdb(train_data, vocab))
test_set = gdata.ArrayDataset(*preprocess_imdb(test_data, vocab))
train_iter = gdata.DataLoader(train_set, opt.batch_size, shuffle=True)
test_iter = gdata.DataLoader(test_set, opt.batch_size)

embed_size, num_hiddens, num_layers = 100, 100, 2

net = BiRNN(vocab, embed_size, num_hiddens, num_layers)
net.initialize(init.Xavier(), ctx=ctx)

glove_embedding = text.embedding.create(
'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab)

net.embedding.weight.set_data(glove_embedding.idx_to_vec)
net.embedding.collect_params().setattr('grad_req', 'null')

trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': opt.lr})
loss = gloss.SoftmaxCrossEntropyLoss()

# train
e = est.Estimator(net, loss=loss, trainers=trainer, context=ctx)
e.fit(train_iter, test_iter, epochs=opt.epochs, batch_size=opt.batch_size)