From a75158792fc0164537660a8ef6ecbc2dc934fcc4 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 2 Aug 2018 11:27:27 -0700 Subject: [PATCH 01/43] Model is assembled, but TODOs need to be addressed --- scripts/question_answering/attention_flow.py | 64 ++++ scripts/question_answering/bidaf.py | 74 +++++ scripts/question_answering/data_processing.py | 61 +++- scripts/question_answering/metric.py | 113 +++++++ .../question_answering/question_answering.py | 244 +++++++++++++++ .../question_answering/similarity_function.py | 184 +++++++++++ .../train_question_answering.py | 268 ++++++++++++++++ scripts/question_answering/utils.py | 293 ++++++++++++++++++ scripts/tests/test_question_answering.py | 88 +++++- 9 files changed, 1377 insertions(+), 12 deletions(-) create mode 100644 scripts/question_answering/attention_flow.py create mode 100644 scripts/question_answering/bidaf.py create mode 100644 scripts/question_answering/metric.py create mode 100644 scripts/question_answering/question_answering.py create mode 100644 scripts/question_answering/similarity_function.py create mode 100644 scripts/question_answering/train_question_answering.py create mode 100644 scripts/question_answering/utils.py diff --git a/scripts/question_answering/attention_flow.py b/scripts/question_answering/attention_flow.py new file mode 100644 index 0000000000..5154704772 --- /dev/null +++ b/scripts/question_answering/attention_flow.py @@ -0,0 +1,64 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from mxnet import gluon + +from .similarity_function import DotProductSimilarity + + +class AttentionFlow(gluon.HybridBlock): + """ + This ``block`` takes two ndarrays as input and returns a ndarray of attentions. + + We compute the similarity between each row in each matrix and return unnormalized similarity + scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the + caller to deal with masking properly when this output is used. + + By default similarity is computed with a dot product, but you can alternatively use a + parameterized similarity function if you wish. + + + Input: + - ndarray_1: ``(batch_size, num_rows_1, embedding_dim)`` + - ndarray_2: ``(batch_size, num_rows_2, embedding_dim)`` + + Output: + - ``(batch_size, num_rows_1, num_rows_2)`` + + Parameters + ---------- + similarity_function: ``SimilarityFunction``, optional (default=``DotProductSimilarity``) + The similarity function to use when computing the attention. + """ + def __init__(self, similarity_function, **kwargs): + super(AttentionFlow, self).__init__(**kwargs) + + self._similarity_function = similarity_function or DotProductSimilarity() + + def hybrid_forward(self, F, matrix_1, matrix_2): + # pylint: disable=arguments-differ + tiled_matrix_1 = matrix_1.expand_dims(2).broadcast_to(shape=(matrix_1.shape[0], + matrix_1.shape[1], + matrix_2.shape[1], + matrix_1.shape[2])) + tiled_matrix_2 = matrix_2.expand_dims(1).broadcast_to(shape=(matrix_2.shape[0], + matrix_1.shape[1], + matrix_2.shape[1], + matrix_2.shape[2])) + return self._similarity_function(tiled_matrix_1, tiled_matrix_2) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py new file mode 100644 index 0000000000..cb2a3df132 --- /dev/null +++ b/scripts/question_answering/bidaf.py @@ -0,0 +1,74 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from mxnet import gluon + +from .attention_flow import AttentionFlow +from .utils import last_dim_softmax, weighted_sum, replace_masked_values, masked_softmax + + +class BidirectionalAttentionFlow(gluon.HybridBlock): + """ + This class implements Minjoon Seo's `Bidirectional Attention Flow model + `_ + for answering reading comprehension questions (ICLR 2017). + + """ + def __init__(self, + attention_similarity_function, + **kwargs): + super(BidirectionalAttentionFlow, self).__init__(**kwargs) + + self._matrix_attention = AttentionFlow(attention_similarity_function) + + def hybrid_forward(self, F, encoded_passage, encoded_question, + question_mask, passage_mask, batch_size, passage_length, encoding_dim): + # pylint: disable=arguments-differ + """ + """ + + # Shape: (batch_size, passage_length, question_length) + passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) + # Shape: (batch_size, passage_length, question_length) + passage_question_attention = last_dim_softmax(passage_question_similarity, question_mask) + # Shape: (batch_size, passage_length, encoding_dim) + passage_question_vectors = weighted_sum(encoded_question, passage_question_attention) + + # We replace masked values with something really negative here, so they don't affect the + # max below. + masked_similarity = replace_masked_values(passage_question_similarity, + question_mask.expand_dims(1), + -1e7) + # Shape: (batch_size, passage_length) + question_passage_similarity = masked_similarity.max(axis=-1)[0] + # Shape: (batch_size, passage_length) + question_passage_attention = masked_softmax(question_passage_similarity, passage_mask) + # Shape: (batch_size, encoding_dim) + question_passage_vector = weighted_sum(encoded_passage, question_passage_attention) + # Shape: (batch_size, passage_length, encoding_dim) + tiled_question_passage_vector = question_passage_vector.expand_dims(1).expand(batch_size, + passage_length, + encoding_dim) + + # Shape: (batch_size, passage_length, encoding_dim * 4) + final_merged_passage = F.cat([encoded_passage, + passage_question_vectors, + encoded_passage * passage_question_vectors, + encoded_passage * tiled_question_passage_vector], + dim=-1) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index fdce8011df..adff15d278 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -31,12 +31,13 @@ from gluonnlp.data.batchify import Pad -def preprocess_dataset(dataset, question_max_length, context_max_length): +def preprocess_dataset(dataset, question_max_length, context_max_length, max_chars_per_word): """Process SQuAD dataset by creating NDArray version of data :param Dataset dataset: SQuAD dataset :param int question_max_length: Maximum length of question (padded or trimmed to that size) :param int context_max_length: Maximum length of context (padded or trimmed to that size) + :param int max_chars_per_word: Maximum length of word (padded or trimmed to that size) Returns ------- @@ -44,7 +45,8 @@ def preprocess_dataset(dataset, question_max_length, context_max_length): Dataset of preprocessed records """ vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) processed_dataset = SimpleDataset(dataset.trasform(transformer, lazy=False)) return processed_dataset @@ -53,12 +55,13 @@ class SQuADTransform(object): """SQuADTransform class responsible for converting text data into NDArrays that can be later feed into DataProvider """ - def __init__(self, vocab_provider, question_max_length, context_max_length): + def __init__(self, vocab_provider, question_max_length, context_max_length, max_chars_per_word): self._word_vocab = vocab_provider.get_word_level_vocab() self._char_vocab = vocab_provider.get_char_level_vocab() self._question_max_length = question_max_length self._context_max_length = context_max_length + self._max_chars_per_word = max_chars_per_word self._padder = Pad() @@ -77,16 +80,19 @@ def __call__(self, record_index, question_id, question, context, answer_list, context_chars = [self._char_vocab[list(iter(word))] for word in context.split()[:self._context_max_length]] - question_words_nd = nd.array(question_words, dtype=np.int32) + question_words_nd = self._pad_to_max_word_length(question_words, self._question_max_length) question_chars_nd = self._padder(question_chars) + question_chars_nd = self._pad_to_max_char_length(question_chars_nd, + self._question_max_length) - context_words_nd = nd.array(context_words, dtype=np.int32) + context_words_nd = self._pad_to_max_word_length(context_words, self._context_max_length) context_chars_nd = self._padder(context_chars) + context_chars_nd = self._pad_to_max_char_length(context_chars_nd, self._context_max_length) answer_spans = SQuADTransform._get_answer_spans(answer_list, answer_start_list) - return record_index, question_id, question_words_nd, context_words_nd, \ - question_chars_nd, context_chars_nd, answer_spans + return (record_index, question_id, question_words_nd, context_words_nd, + question_chars_nd, context_chars_nd, answer_spans) @staticmethod def _get_answer_spans(answer_list, answer_start_list): @@ -103,6 +109,47 @@ def _get_answer_spans(answer_list, answer_start_list): return [(answer_start_list[i], answer_start_list[i] + len(answer)) for i, answer in enumerate(answer_list)] + def _pad_to_max_char_length(self, item, max_item_length): + """Pads all tokens to maximum size + + :param NDArray item: matrix of indices + :param int max_item_length: maximum length of a token + :return: + """ + # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims + data_expanded = item.reshape(1, 1, item.shape[0], item.shape[1]).astype(np.float32) + data_padded = nd.pad(data_expanded, + mode='constant', + pad_width=[0, 0, 0, 0, 0, max_item_length - item.shape[0], + 0, self._max_chars_per_word - item.shape[1]], + constant_value=0) + + # reshape back to original dimensions with the last dimension of max_item_length + # We also convert to float32 because it will be necessary later for processing + data_reshaped_back = data_padded.reshape(max_item_length, + self._max_chars_per_word).astype(np.float32) + return data_reshaped_back + + @staticmethod + def _pad_to_max_word_length(item, max_length): + """Pads sentences to maximum length + + :param NDArray item: vector of words + :param int max_length: Maximum length of question/context + :return: + """ + data_nd = nd.array(item, dtype=np.float32) + # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims + data_expanded = data_nd.reshape(1, 1, 1, data_nd.shape[0]) + data_padded = nd.pad(data_expanded, + mode='constant', + pad_width=[0, 0, 0, 0, 0, 0, 0, max_length - data_nd.shape[0]], + constant_value=0) + # reshape back to original dimensions with the last dimension of max_length + # We also convert to float32 because it will be necessary later for processing + data_reshaped_back = data_padded.reshape(max_length).astype(np.float32) + return data_reshaped_back + class VocabProvider(object): """Provides word level and character level vocabularies diff --git a/scripts/question_answering/metric.py b/scripts/question_answering/metric.py new file mode 100644 index 0000000000..e3d01c4d70 --- /dev/null +++ b/scripts/question_answering/metric.py @@ -0,0 +1,113 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Official evaluation script for v1.1 of the SQuAD dataset. """ +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {'exact_match': exact_match, 'f1': f1} + + +if __name__ == '__main__': + expected_version = '1.1' + parser = argparse.ArgumentParser( + description='Evaluation for SQuAD ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (dataset_json['version'] != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py new file mode 100644 index 0000000000..467c69783f --- /dev/null +++ b/scripts/question_answering/question_answering.py @@ -0,0 +1,244 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""BiDAF model blocks""" +from mxnet.gluon.nn import HybridSequential + +from scripts.question_answering.attention_flow import AttentionFlow +from scripts.question_answering.similarity_function import DotProductSimilarity + +__all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] + +from mxnet import nd, init +from mxnet.gluon import Block +from mxnet.gluon import nn +from mxnet.gluon.rnn import LSTM + +from gluonnlp.model import ConvolutionalEncoder, Highway + + +class BiDAFEmbedding(Block): + """BiDAFEmbedding is a class describing embeddings that are separately applied to question + and context of the datasource. Both question and context are passed in two NDArrays: + 1. Matrix of words: batch_size x words_per_question/context + 2. Tensor of characters: batch_size x words_per_question/context x chars_per_word + """ + def __init__(self, word_vocab, char_vocab, contextual_embedding_nlayers=2, highway_nlayers=2, + embedding_size=100, prefix=None, params=None): + super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) + + self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), output_dim=8) + self._char_conv_embedding = ConvolutionalEncoder( + embed_size=8, + num_filters=(100,), + ngram_filter_sizes=(5,), + num_highway=None, + conv_layer_activation='relu', + output_size=None + ) + + self._word_embedding = nn.Embedding(input_dim=len(word_vocab), output_dim=embedding_size, + weight_initializer=init.Constant( + word_vocab.embedding.idx_to_vec)) + + self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) + self._contextual_embedding = LSTM(hidden_size=embedding_size, + num_layers=contextual_embedding_nlayers, + bidirectional=True) + + def forward(self, x, contextual_embedding_state=None): # pylint: disable=arguments-differ + # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively + word_level_data = nd.transpose(x[0], axes=(1, 0)) + char_level_data = nd.transpose(x[1], axes=(1, 0, 2)) + + # Get word embeddings. Output is batch_size x seq_len x embedding size (100) + word_embedded = self._word_embedding(word_level_data) + + # Get char level embedding in multiple steps: + # Step 1. Embed into 8-dim vector + char_level_data = self._char_dense_embedding(char_level_data) + + # Step 2. Transpose to put seq_len first axis to later iterate over it + # In that way we can get embedding per token of every batch + char_level_data = nd.transpose(char_level_data, axes=(0, 2, 1, 3)) + + # Step 3. Iterate over tokens of each batch and apply convolutional encoder + # As a result of a single iteration, we get token embedding for every batch + token_list = [] + for token_of_all_batches in char_level_data: + token_list.append(self._char_conv_embedding(token_of_all_batches)) + + # Step 4. Concat all tokens embeddings to create a single tensor. + char_embedded = nd.concat(*token_list, dim=0) + + # Step 5. Reshape tensor to match dimensions of embedded words + char_embedded = char_embedded.reshape(shape=word_embedded.shape) + + # Concat embeddings, making channels size = 200 + highway_input = nd.concat(char_embedded, word_embedded, dim=2) + # Pass through highway, shape remains unchanged + highway_output = self._highway_network(highway_input) + + # Pass through contextual embedding, which is just bi-LSTM + ce_output, ce_state = self._contextual_embedding(highway_output, + contextual_embedding_state) + + return ce_output, ce_state + + +class BiDAFModelingLayer(Block): + """BiDAFModelingLayer implements modeling layer of BiDAF paper. It is used to scan over context + produced by Attentional Flow Layer via 2 layer bi-LSTM. + + The input data for the forward should be of dimension 8 * hidden_size (default hidden_size + is 100). + + Parameters + ---------- + + input_dim : `int`, default 100 + The number of features in the hidden state h of LSTM + nlayers : `int`, default 2 + Number of recurrent layers. + biflag: `bool`, default True + If `True`, becomes a bidirectional RNN. + dropout: `float`, default 0 + If non-zero, introduces a dropout layer on the outputs of each + RNN layer except the last layer. + prefix : `str` or None + Prefix of this `Block`. + params : `ParameterDict` or `None` + Shared Parameters for this `Block`. + """ + def __init__(self, input_dim=100, nlayers=2, biflag=True, + dropout=0.2, prefix=None, params=None): + super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) + + self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, + bidirectional=biflag) + + def forward(self, x): # pylint: disable=arguments-differ + out = self._modeling_layer(x) + return out + + +class BiDAFOutputLayer(Block): + """ + ``BiDAFOutputLayer`` produces the final prediction of an answer. The output is a tuple of + start index and end index of the answer in the paragraph per each batch. + + It accepts 2 inputs: + `x` : the output of Attention layer of shape: + seq_max_length x batch_size x 8 * span_start_input_dim + + `m` : the output of Modeling layer of shape: + seq_max_length x batch_size x 2 * span_start_input_dim + + Parameters + ---------- + span_start_input_dim : `int`, default 100 + The number of features in the hidden state h of LSTM + units : `int`, default 10 * ``span_start_input_dim`` + Number of hidden units of `Dense` layer + nlayers : `int`, default 1 + Number of recurrent layers. + biflag: `bool`, default True + If `True`, becomes a bidirectional RNN. + dropout: `float`, default 0 + If non-zero, introduces a dropout layer on the outputs of each + RNN layer except the last layer. + prefix : `str` or None + Prefix of this `Block`. + params : `ParameterDict` or `None` + Shared Parameters for this `Block`. + """ + def __init__(self, span_start_input_dim=100, units=None, nlayers=1, biflag=True, + dropout=0.2, prefix=None, params=None): + super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) + + units = 10 * span_start_input_dim if units is None else units + + self._start_index_dense = nn.Dense(units=units) + self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, + num_layers=nlayers, dropout=dropout, bidirectional=biflag) + self._end_index_dense = nn.Dense(units=units) + + def forward(self, x, m): # pylint: disable=arguments-differ + # setting batch size as the first dimension + start_index_input = nd.transpose(nd.concat(x, m, dim=2), axes=(1, 0, 2)) + start_index_dense_output = self._start_index_dense(start_index_input) + + end_index_input_part = self._end_index_lstm(m) + end_index_input = nd.transpose(nd.concat(x, end_index_input_part, dim=2), + axes=(1, 0, 2)) + + end_index_dense_output = self._end_index_dense(end_index_input) + + # TODO: Loss function applies softmax by default, so this code is commented here + # Will need to reuse it to actually make predictions + # start_index_softmax_output = start_index_dense_output.softmax(axis=1) + # start_index = nd.argmax(start_index_softmax_output, axis=1) + # end_index_softmax_output = end_index_dense_output.softmax(axis=1) + # end_index = nd.argmax(end_index_softmax_output, axis=1) + + # producing output in shape 2 x batch_size x units + output = nd.concat(nd.expand_dims(start_index_dense_output, axis=0), + nd.expand_dims(end_index_dense_output, axis=0), dim=0) + + # transposing it to batch_size x 2 x units + return nd.transpose(output, axes=(1, 0, 2)) + + +class BiDAFModel(Block): + """Bidirectional attention flow model for Question answering + """ + + def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): + super().__init__(prefix=prefix, params=params) + + with self.name_scope(): + self._ctx_embedding = BiDAFEmbedding(word_vocab, char_vocab, + options.ctx_embedding_num_layers, + options.highway_num_layers, + options.embedding_size, + prefix="context_embedding") + self._q_embedding = BiDAFEmbedding(word_vocab, char_vocab, + options.ctx_embedding_num_layers, + options.highway_num_layers, + options.embedding_size, + prefix="question_embedding") + self._attention_layer = AttentionFlow(DotProductSimilarity()) + self._modeling_layer = BiDAFModelingLayer(input_dim=options.embedding_size, + nlayers=options.modeling_num_layers, + dropout=options.dropout) + self._output_layer = BiDAFOutputLayer(span_start_input_dim=options.embedding_size, + nlayers=options.output_num_layers, + dropout=options.dropout) + + def forward(self, x, ctx_embedding_states, q_embedding_states, *args): + ctx_embedding_output, ctx_embedding_state = self._ctx_embedding([x[2], x[4]], + ctx_embedding_states) + q_embedding_output, q_embedding_state = self._q_embedding([x[1], x[3]], + q_embedding_states) + + attention_layer_output = self._attention_layer(ctx_embedding_output, q_embedding_output) + modeling_layer_output = self._modeling_layer(attention_layer_output) + output = self._output_layer(attention_layer_output, modeling_layer_output) + + return output, ctx_embedding_state, q_embedding_state diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py new file mode 100644 index 0000000000..0b52e24bf7 --- /dev/null +++ b/scripts/question_answering/similarity_function.py @@ -0,0 +1,184 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +from mxnet.gluon import nn, Parameter + +from .utils import combine_tensors + + +class SimilarityFunction(gluon.HybridBlock): + """ + A ``SimilarityFunction`` takes a pair of tensors with the same shape, and computes a similarity + function on the vectors in the last dimension. For example, the tensors might both have shape + `(batch_size, sentence_length, embedding_dim)`, and we will compute some function of the two + vectors of length `embedding_dim` for each position `(batch_size, sentence_length)`, returning a + tensor of shape `(batch_size, sentence_length)`. + + The similarity function could be as simple as a dot product, or it could be a more complex, + parameterized function. + """ + default_implementation = 'dot_product' + + def hybrid_forward(self, F, array_1, array_2): + # pylint: disable=arguments-differ + """ + Takes two tensors of the same shape, such as ``(batch_size, length_1, length_2, + embedding_dim)``. Computes a (possibly parameterized) similarity on the final dimension + and returns a tensor with one less dimension, such as ``(batch_size, length_1, length_2)``. + """ + raise NotImplementedError + + +class DotProductSimilarity(SimilarityFunction): + """ + This similarity function simply computes the dot product between each pair of vectors, with an + optional scaling to reduce the variance of the output elements. + + Parameters + ---------- + scale_output : ``bool``, optional + If ``True``, we will scale the output by ``F.sqrt(ndarray.shape[-1])``, to reduce the + variance in the result. + """ + def __init__(self, scale_output=False, **kwargs): + super(DotProductSimilarity, self).__init__(**kwargs) + self._scale_output = scale_output + + def hybrid_forward(self, F, array_1, array_2): + result = (array_1 * array_2).sum(axis=-1) + + if self._scale_output: + result *= F.sqrt(array_1.shape[-1]) + + return result + + +class CosineSimilarity(SimilarityFunction): + """ + This similarity function simply computes the cosine similarity between each pair of vectors. + It has no parameters. + """ + + def hybrid_forward(self, F, array_1, array_2): + normalized_array_1 = array_1 / F.norm(array_1, axis=-1, keepdims=True) + normalized_array_2 = array_2 / F.norm(array_2, axis=-1, keepdims=True) + return (normalized_array_1 * normalized_array_2).sum(axis=-1) + + +class BilinearSimilarity(SimilarityFunction): + """ + This similarity function performs a bilinear transformation of the two input vectors. This + function has a matrix of weights ``W`` and a bias ``b``, and the similarity between two vectors + ``x`` and ``y`` is computed as ``x^T W y + b``. + + Parameters + ---------- + array_1_dim : ``int`` + The dimension of the first ndarray, ``x``, described above. This is ``x.shape[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build the weight matrix correctly. + array_2_dim : ``int`` + The dimension of the second ndarray, ``y``, described above. This is ``y.shape[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build the weight matrix correctly. + activation : ``Activation``, optional (default=linear (i.e. no activation)) + An activation function applied after the ``x^T W y + b`` calculation. Default is no + activation. + """ + def __init__(self, + array_1_dim, + array_2_dim, + activation='linear', + **kwargs): + super(BilinearSimilarity, self).__init__(**kwargs) + self._weight_matrix = Parameter(name="weight_matrix", + shape=(array_1_dim, array_2_dim), init=mx.init.Xavier()) + self._bias = Parameter(name="bias", shape=(array_1_dim,), init=mx.init.Zero()) + + if activation == 'linear': + self._activation = None + else: + self._activation = nn.Activation(activation) + + def hybrid_forward(self, F, array_1, array_2): + intermediate = F.broadcast_mull(array_1, self._weight_matrix) + result = F.broadcast_mull(intermediate, array_2).sum(axis=-1) + + if not self._activation: + return result + + return self._activation(result + self._bias) + + +class LinearSimilarity(SimilarityFunction): + """ + This similarity function performs a dot product between a vector of weights and some + combination of the two input vectors, followed by an (optional) activation function. The + combination used is configurable. + + If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``, + ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed + elementwise. You can list as many combinations as you want, comma separated. For example, you + might give ``x,y,x*y`` as the ``combination`` parameter to this class. The computed similarity + function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a + bias parameter, and ``[;]`` is vector concatenation. + + Parameters + ---------- + array_1_dim : ``int`` + The dimension of the first tensor, ``x``, described above. This is ``x.size()[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build weight vectors correctly. + array_2_dim : ``int`` + The dimension of the second tensor, ``y``, described above. This is ``y.size()[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build weight vectors correctly. + combination : ``str``, optional (default="x,y") + Described above. + activation : ``Activation``, optional (default=linear (i.e. no activation)) + An activation function applied after the ``w^T * [x;y] + b`` calculation. Default is no + activation. + """ + def __init__(self, + array_1_dim, + array_2_dim, + combination='x,y', + activation='linear', + **kwargs): + super(LinearSimilarity, self).__init__(**kwargs) + self._combination = combination + self._weight_matrix = Parameter(name="weight_matrix", + shape=(array_1_dim, array_2_dim), init=mx.init.Uniform()) + self._bias = Parameter(name="bias", shape=(array_1_dim,), init=mx.init.Zero()) + + if activation == 'linear': + self._activation = None + else: + self._activation = nn.Activation(activation) + + def hybrid_forward(self, F, array_1, array_2): + combined_tensors = combine_tensors(self._combination, [array_1, array_1]) + dot_product = F.broadcast_mull(combined_tensors, self._weight_matrix) + + if not self._activation: + return dot_product + + return self._activation(dot_product + self._bias) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py new file mode 100644 index 0000000000..db3fccb6bb --- /dev/null +++ b/scripts/question_answering/train_question_answering.py @@ -0,0 +1,268 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import numpy as np +import random +from time import time + +import mxnet as mx +from mxnet import init, autograd +from mxnet.gluon import Trainer +from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss + +import gluonnlp as nlp +from gluonnlp.data import SQuAD + +from scripts.question_answering.data_processing import VocabProvider, SQuADTransform +from scripts.question_answering.metric import f1_score, exact_match_score +from scripts.question_answering.question_answering import * +from scripts.question_answering.utils import logging_config + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + + +def get_data(is_train, options): + """Get dataset and dataloader + + Parameters + ---------- + is_train : `bool` + If `True`, training SQuAD dataset is loaded, if `False` valiidation dataset is loaded + options : `Namespace` + Data transformation arguments + + Returns + ------- + data : Tuple + A tuple of dataset and dataloader + """ + dataset = SQuAD(segment='train' if is_train else 'val') + vocab_provider = VocabProvider(dataset) + transformer = SQuADTransform(vocab_provider, options.q_max_len, + options.ctx_max_len, options.word_max_len) + # TODO: Data processing takes too long for doing experementation + # set it to 256 to speed up thing, but need to refactor this to maybe store processed dataset + # and vocabs. 256 is not a random number, it is 2 * batch_size, so the last batch won't cause + # Invalid recurrent state shape after first batch is finished + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < 256]) + + data_no_label = [] + labels = [] + global_index = 0 + + # copy records to a record per answer + for r in processed_dataset: + # creating a set out of answer_span will deduplicate them + for answer_span in set(r[6]): + # need to remove question id before feeding the data to data loader + # And I also replace index with global_index when unrolling answers + data_no_label.append((global_index, r[2], r[3], r[4], r[5])) + labels.append(mx.nd.array(answer_span)) + global_index += 1 + + loadable_data = ArrayDataset(data_no_label, labels) + dataloader = DataLoader(loadable_data, batch_size=options.batch_size, shuffle=True, + last_batch='discard') + + return dataset, dataloader + + +def get_vocabs(dataset, options): + """Get word-level and character-level vocabularies + + Parameters + ---------- + dataset : `SQuAD` + SQuAD dataset to build vocab from + options : `Namespace` + Vocab building arguments + + Returns + ------- + data : Tuple + A tuple of word vocabulary and character vocabulary + """ + vocab_provider = VocabProvider(dataset) + + word_vocab = vocab_provider.get_word_level_vocab() + + word_vocab.set_embedding( + nlp.embedding.create('glove', source='glove.6B.{}d'.format(options.embedding_size))) + + char_vocab = vocab_provider.get_char_level_vocab() + return word_vocab, char_vocab + + +def get_context(options): + """Return context list to work on + + Parameters + ---------- + options : `Namespace` + Training arguments + + """ + if options.gpu is None: + ctx = mx.cpu() + print('Use CPU') + else: + ctx = mx.gpu(options.gpu) + + return ctx + + +def run_training(net, dataloader, options): + """Get word-level and character-level vocabularies + + Parameters + ---------- + net : `Block` + Network to train + dataloader : `DataLoader` + Initialized dataloader + options : `Namespace` + Training arguments + + Returns + ------- + data : Tuple + A tuple of word vocabulary and character vocabulary + """ + ctx = get_context(options) + + trainer = Trainer(net.collect_params(), args.optimizer, {'learning_rate': options.lr}) + eval_metrics = mx.metric.CompositeEvalMetric(metrics=[ + mx.metric.create(lambda label, pred: f1_score(pred, label)), + mx.metric.create(lambda label, pred: exact_match_score(pred, label)) + ]) + loss_function = SoftmaxCrossEntropyLoss() + + contextual_embedding_param_shape = (4, options.batch_size, options.embedding_size) + ctx_initial_embedding_h0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) + ctx_initial_embedding_c0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) + q_initial_embedding_h0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) + q_initial_embedding_c0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) + + ctx_embedding = [ctx_initial_embedding_h0, ctx_initial_embedding_c0] + q_embedding = [q_initial_embedding_h0, q_initial_embedding_c0] + + train_start = time() + avg_loss = mx.nd.zeros((1,), ctx=ctx) + + for epoch_id in range(args.epochs): + avg_loss *= 0 # Zero average loss of each epoch + eval_metrics.reset() # reset metrics before each epoch + + for i, (data, label) in enumerate(dataloader): + # start timing for the first batch of epoch + if i == 0: + e_start = time() + + record_index, q_words, ctx_words, q_chars, ctx_chars = data + q_words = q_words.as_in_context(ctx) + ctx_words = ctx_words.as_in_context(ctx) + q_chars = q_chars.as_in_context(ctx) + ctx_chars = ctx_chars.as_in_context(ctx) + label = label.as_in_context(ctx) + + with autograd.record(): + output, ctx_embedding, q_embedding = net((record_index, q_words, ctx_words, q_chars, + ctx_chars), ctx_embedding, q_embedding) + loss = loss_function(output, label) + + loss.backward() + trainer.step(options.batch_size) + + avg_loss += loss.mean().as_in_context(avg_loss.context) + + # TODO: Update eval metrics calculation with actual predictions + # eval_metrics.update(label, output) + + # i here would be equal to number of batches + # if multi-GPU, will also need to multiple by GPU qty + avg_loss /= i + epoch_time = time() - e_start + metrics = eval_metrics.get() + # TODO: Fix metrics, by using metric.py - original estimator + # Again, in multi-gpu environment multiple i by GPU qty + # avg_metrics = [metric / i for metric in metrics[1]] + # epoch_metrics = (metrics[0], avg_metrics) + + print("\tEPOCH {:2}: train loss {:4.2f} | batch {:4} | lr {:5.3f} | " + "Time per epoch {:5.2f} seconds" + .format(i, avg_loss.asscalar(), options.batch_size, trainer.learning_rate, + epoch_time)) + + print("Training time {:6.2f} seconds".format(time() - train_start)) + + +def get_args(): + parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') + parser.add_argument('--epochs', type=int, default=40, help='Upper epoch limit') + parser.add_argument('--embedding_size', type=int, default=100, + help='Dimension of the word embedding') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, + help='Number of layers in Contextual embedding layer of BiDAF') + parser.add_argument('--highway_num_layers', type=int, default=2, + help='Number of layers in Highway layer of BiDAF') + parser.add_argument('--modeling_num_layers', type=int, default=2, + help='Number of layers in Modeling layer of BiDAF') + parser.add_argument('--output_num_layers', type=int, default=1, + help='Number of layers in Output layer of BiDAF') + parser.add_argument('--batch_size', type=int, default=128, help='Batch size') + parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') + # TODO: Question max length in the paper is 30. Had to set it 400 to make dot_product + # similarity work + parser.add_argument('--q_max_len', type=int, default=400, help='Maximum length of a question') + parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') + parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') + parser.add_argument('--lr', type=float, default=1E-3, help='Initial learning rate') + parser.add_argument('--lr_update_factor', type=float, default=0.5, + help='Learning rate decay factor') + parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') + parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') + parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') + parser.add_argument('--gpu', type=int, default=None, + help='id of the gpu to use. Set it to empty means to use cpu.') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + print(args) + logging_config(args.save_dir) + + train_dataset, train_dataloader = get_data(is_train=True, options=args) + word_vocab, char_vocab = get_vocabs(train_dataset, options=args) + + net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") + net.initialize(init.Xavier(magnitude=2.24)) + + run_training(net, train_dataloader, args) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py new file mode 100644 index 0000000000..e9ce838599 --- /dev/null +++ b/scripts/question_answering/utils.py @@ -0,0 +1,293 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import inspect +import logging +from mxnet import nd + + +def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=logging.INFO, + no_console=False): + """ Config the logging. + + Parameters + ---------- + folder : str or None + name : str or None + level : int + console_level + no_console: bool + Whether to disable the console log + Returns + ------- + folder : str + Folder that the logging file will be saved into. + """ + if name is None: + name = inspect.stack()[1][1].split('.')[0] + + if folder is None: + folder = os.path.join(os.getcwd(), name) + + if not os.path.exists(folder): + os.makedirs(folder) + + # Remove all the current handlers + for handler in logging.root.handlers: + logging.root.removeHandler(handler) + + logging.root.handlers = [] + logpath = os.path.join(folder, name + '.log') + print('All Logs will be saved to {}'.format(logpath)) + logging.root.setLevel(level) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s') + logfile = logging.FileHandler(logpath) + logfile.setLevel(level) + logfile.setFormatter(formatter) + logging.root.addHandler(logfile) + + if not no_console: + # Initialze the console logging + logconsole = logging.StreamHandler() + logconsole.setLevel(console_level) + logconsole.setFormatter(formatter) + logging.root.addHandler(logconsole) + + return folder + + +def get_combined_dim(combination, tensor_dims): + """ + For use with :func:`combine_tensors`. This function computes the resultant dimension when + calling ``combine_tensors(combination, tensors)``, when the tensor dimension is known. This is + necessary for knowing the sizes of weight matrices when building models that use + ``combine_tensors``. + + Parameters + ---------- + combination : ``str`` + A comma-separated list of combination pieces, like ``"1,2,1*2"``, specified identically to + ``combination`` in :func:`combine_tensors`. + tensor_dims : ``List[int]`` + A list of tensor dimensions, where each dimension is from the `last axis` of the tensors + that will be input to :func:`combine_tensors`. + """ + combination = combination.replace('x', '1').replace('y', '2') + return sum([_get_combination_dim(piece, tensor_dims) for piece in combination.split(',')]) + + +def _get_combination_dim(combination, tensor_dims): + if combination.isdigit(): + index = int(combination) - 1 + return tensor_dims[index] + else: + first_tensor_dim = _get_combination_dim(combination[0], tensor_dims) + second_tensor_dim = _get_combination_dim(combination[2], tensor_dims) + operation = combination[1] + return first_tensor_dim + + +def _get_combination(combination, tensors): + if combination.isdigit(): + index = int(combination) - 1 + return tensors[index] + else: + first_tensor = _get_combination(combination[0], tensors) + second_tensor = _get_combination(combination[2], tensors) + operation = combination[1] + if operation == '*': + return first_tensor * second_tensor + elif operation == '/': + return first_tensor / second_tensor + elif operation == '+': + return first_tensor + second_tensor + elif operation == '-': + return first_tensor - second_tensor + else: + raise NotImplementedError + + +def combine_tensors(combination, tensors): + """ + Combines a list of tensors using element-wise operations and concatenation, specified by a + ``combination`` string. The string refers to (1-indexed) positions in the input tensor list, + and looks like ``"1,2,1+2,3-1"``. + + We allow the following kinds of combinations: ``x``, ``x*y``, ``x+y``, ``x-y``, and ``x/y``, + where ``x`` and ``y`` are positive integers less than or equal to ``len(tensors)``. Each of + the binary operations is performed elementwise. You can give as many combinations as you want + in the ``combination`` string. For example, for the input string ``"1,2,1*2"``, the result + would be ``[1;2;1*2]``, as you would expect, where ``[;]`` is concatenation along the last + dimension. + + If you have a fixed, known way to combine tensors that you use in a model, you should probably + just use something like ``ndarray.cat([x_tensor, y_tensor, x_tensor * y_tensor])``. This + function adds some complexity that is only necessary if you want the specific combination used + to be `configurable`. + + If you want to do any element-wise operations, the tensors involved in each element-wise + operation must have the same shape. + + This function also accepts ``x`` and ``y`` in place of ``1`` and ``2`` in the combination + string. + """ + combination = combination.replace('x', '1').replace('y', '2') + to_concatenate = [_get_combination(piece, tensors) for piece in combination.split(',')] + return nd.concat(to_concatenate, dim=-1) + + +def masked_softmax(vector, mask): + """ + ``nd.softmax(vector)`` does not work if some elements of ``vector`` should be + masked. This performs a softmax on just the non-masked portions of ``vector``. Passing + ``None`` in for the mask is also acceptable; you'll just get a regular softmax. + + We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``. + + In the case that the input vector is completely masked, this function returns an array + of ``0.0``. This behavior may cause ``NaN`` if this is used as the last layer of a model + that uses categorical cross-entropy loss. + """ + if mask is None: + result = nd.softmax(vector, axis=-1) + else: + # To limit numerical errors from large vector elements outside the mask, we zero these out. + result = nd.softmax(vector * mask, axis=-1) + result = result * mask + result = result / (result.sum(axis=1, keepdims=True) + 1e-13) + return result + + +def masked_log_softmax(vector, mask): + """ + ``nd.log_softmax(vector)`` does not work if some elements of ``vector`` should be + masked. This performs a log_softmax on just the non-masked portions of ``vector``. Passing + ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax. + + We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``. + + In the case that the input vector is completely masked, the return value of this function is + arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out + of this in that case, anyway, so the specific values returned shouldn't matter. Also, the way + that we deal with this case relies on having single-precision floats; mixing half-precision + floats with fully-masked vectors will likely give you ``nans``. + + If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or + lower), the way we handle masking here could mess you up. But if you've got logit values that + extreme, you've got bigger problems than this. + """ + if mask is not None: + # vector + mask.log() is an easy way to zero out masked elements in logspace, but it + # results in nans when the whole vector is masked. We need a very small value instead of a + # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely + # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it + # becomes 0 - this is just the smallest value we can actually use. + vector = vector + (mask + 1e-45).log() + return nd.log_softmax(vector, axis=1) + + +def _last_dimension_applicator(function_to_apply, + tensor, + mask): + """ + Takes a tensor with 3 or more dimensions and applies a function over the last dimension. We + assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) + has shape ``(batch_size, sequence_length)``. We first unsqueeze and expand the mask so that it + has the same shape as the tensor, then flatten them both to be 2D, pass them through + the function and put the tensor back in its original shape. + """ + tensor_shape = tensor.shape + reshaped_tensor = tensor.reshape(-1, tensor.shape[-1]) + if mask is not None: + while mask.shape[0] < tensor.shape[0]: + mask = mask.expand_dims(1) + mask = mask.broadcast_to(tensor).contiguous().float() + mask = mask.reshape(-1, mask.shape[-1]) + reshaped_result = function_to_apply(reshaped_tensor, mask) + return reshaped_result.reshape(*tensor_shape) + + +def last_dim_softmax(tensor, mask): + """ + Takes a tensor with 3 or more dimensions and does a masked softmax over the last dimension. We + assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) + has shape ``(batch_size, sequence_length)``. + """ + return _last_dimension_applicator(masked_softmax, tensor, mask) + + +def last_dim_log_softmax(tensor, mask): + """ + Takes a tensor with 3 or more dimensions and does a masked log softmax over the last dimension. + We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) + has shape ``(batch_size, sequence_length)``. + """ + return _last_dimension_applicator(masked_log_softmax, tensor, mask) + + +def weighted_sum(matrix, attention): + """ + Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an + "attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical + computation performed after an attention mechanism. + + Note that while we call this a "matrix" of vectors and an attention "vector", we also handle + higher-order tensors. We always sum over the second-to-last dimension of the "matrix", and we + assume that all dimensions in the "matrix" prior to the last dimension are matched in the + "vector". Non-matched dimensions in the "vector" must be `directly after the batch dimension`. + + For example, say I have a "matrix" with dimensions ``(batch_size, num_queries, num_words, + embedding_dim)``. The attention "vector" then must have at least those dimensions, and could + have more. Both: + + - ``(batch_size, num_queries, num_words)`` (distribution over words for each query) + - ``(batch_size, num_documents, num_queries, num_words)`` (distribution over words in a + query for each document) + + are valid input "vectors", producing tensors of shape: + ``(batch_size, num_queries, embedding_dim)`` and + ``(batch_size, num_documents, num_queries, embedding_dim)`` respectively. + """ + + if attention.shape[0] == 2 and matrix.shape[0] == 3: + return attention.expand_dims(1).batch_dot(matrix).squeeze(1) + if attention.shape[0] == 3 and matrix.shape[0] == 3: + return attention.batch_dot(matrix) + if matrix.shape[0] - 1 < attention.shape[0]: + expanded_size = list(matrix.shape) + for i in range(attention.shape[0] - matrix.shape[0] + 1): + matrix = matrix.expand_dims(1) + expanded_size.insert(i + 1, attention.shape[i + 1]) + matrix = matrix.broadcast_to(*expanded_size) + intermediate = attention.expand_dims(-1).broadcast_to(matrix.shape) * matrix + return intermediate.sum(axis=-2) + + +def replace_masked_values(tensor, mask, replace_with): + """ + Replaces all masked values in ``tensor`` with ``replace_with``. ``mask`` must be broadcastable + to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we + won't know which dimensions of the mask to unsqueeze. + """ + # We'll build a tensor of the same shape as `tensor`, zero out masked values, then add back in + # the `replace_with` value. + one_minus_mask = 1.0 - mask + values_to_add = replace_with * one_minus_mask + return tensor * mask + values_to_add diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index ffa9259bea..668ce9b852 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -16,21 +16,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os - +from mxnet import init, nd from mxnet.gluon.data import DataLoader, SimpleDataset +import gluonnlp as nlp from gluonnlp.data import SQuAD from scripts.question_answering.data_processing import SQuADTransform, VocabProvider +from scripts.question_answering.question_answering import * question_max_length = 30 -context_max_length = 256 +context_max_length = 400 +max_chars_per_word = 16 +embedding_size = 100 def test_transform_to_nd_array(): dataset = SQuAD(segment='dev', root='tests/data/squad') vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) record = dataset[0] transformed_record = transformer(*record) @@ -41,7 +45,8 @@ def test_transform_to_nd_array(): def test_data_loader_able_to_read(): dataset = SQuAD(segment='dev', root='tests/data/squad') vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) record = dataset[0] processed_dataset = SimpleDataset([transformer(*record)]) @@ -65,3 +70,76 @@ def test_load_vocabs(): assert vocab_provider.get_word_level_vocab() is not None assert vocab_provider.get_char_level_vocab() is not None + + +def test_bidaf_embedding(): + batch_size = 5 + + dataset = SQuAD(segment='dev', root='tests/data/squad') + vocab_provider = VocabProvider(dataset) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) + + # for performance reason, process only batch_size # of records + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < batch_size]) + + # need to remove question id before feeding the data to data loader + loadable_data = SimpleDataset([(r[0], r[2], r[3], r[4], r[5], r[6]) for r in processed_dataset]) + dataloader = DataLoader(loadable_data, batch_size=5) + + word_vocab = vocab_provider.get_word_level_vocab() + word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) + char_vocab = vocab_provider.get_char_level_vocab() + + embedding = BiDAFEmbedding(word_vocab=word_vocab, char_vocab=char_vocab) + embedding.initialize(init.Xavier(magnitude=2.24)) + + contextual_embedding_h0 = nd.random.uniform(shape=(4, batch_size, 100)) + contextual_embedding_c0 = nd.random.uniform(shape=(4, batch_size, 100)) + + for i, data in enumerate(dataloader): + # passing only question_words_nd and question_chars_nd batch + out = embedding([data[1], data[3]], [contextual_embedding_h0, contextual_embedding_c0]) + assert out is not None + break + + +def test_modeling_layer(): + batch_size = 5 + + # The modeling layer receive input in a shape of batch_size x T x 8d + # T is the sequence length of context which is context_max_length + # d is the size of embedding, which is embedding_size + fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size)) + # We assume that attention is already return data in TNC format + attention_output = nd.transpose(fake_data, axes=(1, 0, 2)) + + layer = BiDAFModelingLayer() + # The model doesn't need to know the hidden states, so I don't hold variables for the states + layer.initialize() + + output = layer(attention_output) + # According to the paper, the output should be 2d x T + assert output.shape == (context_max_length, batch_size, 2 * embedding_size) + + +def test_output_layer(): + batch_size = 5 + + # The output layer receive 2 inputs: the output of Modeling layer (context_max_length, + # batch_size, 2 * embedding_size) and the output of Attention flow layer + # (batch_size, context_max_length, 8 * embedding_size) + + # The modeling layer returns data in TNC format + modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size)) + # The layer assumes that attention is already return data in TNC format + attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size)) + + layer = BiDAFOutputLayer() + # The model doesn't need to know the hidden states, so I don't hold variables for the states + layer.initialize() + + output = layer(attention_output, modeling_output) + # We expect final numbers as batch_size x 2 (first start index, second end index) + assert output.shape == (batch_size, 2) From 32787aa723c094965a911f3e2cca7d2c1d7b1e42 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 2 Aug 2018 11:30:19 -0700 Subject: [PATCH 02/43] Commenting useless code --- scripts/question_answering/train_question_answering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index db3fccb6bb..1c12ee1771 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -203,8 +203,8 @@ def run_training(net, dataloader, options): # if multi-GPU, will also need to multiple by GPU qty avg_loss /= i epoch_time = time() - e_start - metrics = eval_metrics.get() - # TODO: Fix metrics, by using metric.py - original estimator + # TODO: Fix metrics by using metric.py - original estimator + # metrics = eval_metrics.get() # Again, in multi-gpu environment multiple i by GPU qty # avg_metrics = [metric / i for metric in metrics[1]] # epoch_metrics = (metrics[0], avg_metrics) From 23f1e9e7832ee9ef6cdee5c9d25a24e9759f02cd Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 2 Aug 2018 12:42:39 -0700 Subject: [PATCH 03/43] Fix epoch time display --- scripts/question_answering/train_question_answering.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 1c12ee1771..63c298ccf6 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -170,7 +170,7 @@ def run_training(net, dataloader, options): train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx) - for epoch_id in range(args.epochs): + for e in range(args.epochs): avg_loss *= 0 # Zero average loss of each epoch eval_metrics.reset() # reset metrics before each epoch @@ -202,6 +202,9 @@ def run_training(net, dataloader, options): # i here would be equal to number of batches # if multi-GPU, will also need to multiple by GPU qty avg_loss /= i + + # block the call here to get correct Time per epoch + avg_loss_scalar = avg_loss.asscalar() epoch_time = time() - e_start # TODO: Fix metrics by using metric.py - original estimator # metrics = eval_metrics.get() @@ -211,7 +214,7 @@ def run_training(net, dataloader, options): print("\tEPOCH {:2}: train loss {:4.2f} | batch {:4} | lr {:5.3f} | " "Time per epoch {:5.2f} seconds" - .format(i, avg_loss.asscalar(), options.batch_size, trainer.learning_rate, + .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, epoch_time)) print("Training time {:6.2f} seconds".format(time() - train_start)) From 9c6c98ef186c3d80727235d3143d28d5f58454c4 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 31 Aug 2018 15:43:40 -0700 Subject: [PATCH 04/43] Multigpu support + official evalualtion --- scripts/question_answering/data_processing.py | 2 +- .../performance_evaluator.py | 130 +++++++++++ .../question_answering/question_answering.py | 11 +- .../question_answering/question_id_mapper.py | 39 ++++ .../train_question_answering.py | 214 +++++++++++------- 5 files changed, 314 insertions(+), 82 deletions(-) create mode 100644 scripts/question_answering/performance_evaluator.py create mode 100644 scripts/question_answering/question_id_mapper.py diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index adff15d278..70d5cbf76f 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -186,8 +186,8 @@ def _create_squad_vocab(tokenization_fn, dataset): all_tokens = [] for data_item in dataset: - all_tokens.extend(tokenization_fn(data_item[1])) all_tokens.extend(tokenization_fn(data_item[2])) + all_tokens.extend(tokenization_fn(data_item[3])) counter = data.count_tokens(all_tokens) vocab = Vocab(counter) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py new file mode 100644 index 0000000000..9a02323dcc --- /dev/null +++ b/scripts/question_answering/performance_evaluator.py @@ -0,0 +1,130 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Performance evaluator - a proxy class used for plugging in official validation script""" +from mxnet import nd, gluon +from mxnet.gluon.data import DataLoader, ArrayDataset +from scripts.question_answering.metric import evaluate + + +class PerformanceEvaluator: + def __init__(self, evaluation_dataset, json_data, question_id_mapper): + self._evaluation_dataset = evaluation_dataset + self._json_data = json_data + self._mapper = question_id_mapper + + def evaluate_performance(self, net, ctx, options): + """Get results of evaluation by official evaluation script + + Parameters + ---------- + net : `Block` + Network + ctx : `Context` + Execution context + options : `Namespace` + Training arguments + + Returns + ------- + data : `dict` + Returns a dictionary of {'exact_match': , 'f1': } + """ + + pred = {} + eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) + for r in self._evaluation_dataset]) + eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, last_batch='keep') + + for i, data in enumerate(eval_dataloader): + record_index, q_words, ctx_words, q_chars, ctx_chars = data + record_index = gluon.utils.split_and_load(record_index, ctx) + q_words = gluon.utils.split_and_load(q_words, ctx) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx) + q_chars = gluon.utils.split_and_load(q_chars, ctx) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx) + + for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, q_chars, ctx_chars): + out, _, _ = net((ri, qw, cw, qc, cc)) + out_per_index = out.transpose(axes=(1, 0, 2)) + start_indices = PerformanceEvaluator._get_index(out_per_index[0]) + end_indices = PerformanceEvaluator._get_index(out_per_index[1]) + + # iterate over batches + for idx, start, end in zip(data[0], start_indices, end_indices): + idx = int(idx.asscalar()) + start = int(start.asscalar()) + end = int(end.asscalar()) + pred[self._mapper.idx_to_question_id[idx]] = self.get_text_result(idx, + (start, end)) + + return evaluate(self._json_data['data'], pred) + + def get_text_result(self, idx, answer_span): + """Converts answer span into actual text from paragraph + + Parameters + ---------- + idx : `int` + Question index + answer_span : `Tuple` + Answer span (start_index, end_index) + + Returns + ------- + text : `str` + A chunk of text for provided answer_span or None if answer span cannot be provided + """ + + start, end = answer_span + + if start > end: + return None + + question_id = self._mapper.idx_to_question_id[idx] + context = self._mapper.question_id_to_context[question_id] + + # start index is above the context length - return cannot provide an answer + if start > len(context) - 1: + return '' + + # end index is above the context length - let's take answer to the end of the context + if end > len(context) - 1: + end = len(context) - 1 + + text = ' '.join(context.split()[start:end + 1]) + return text + + @staticmethod + def _get_index(prediction): + """Convert prediction to actual index in text + + Parameters + ---------- + prediction : `NDArray` + Output of the network + + Returns + ------- + indices : `NDArray` + Indices of a word in context for whole batch + """ + indices_softmax_output = prediction.softmax(axis=1) + indices = nd.argmax(indices_softmax_output, axis=1) + return indices diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 467c69783f..673fb8ea2e 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -63,6 +63,7 @@ def __init__(self, word_vocab, char_vocab, contextual_embedding_nlayers=2, highw bidirectional=True) def forward(self, x, contextual_embedding_state=None): # pylint: disable=arguments-differ + batch_size = x[0].shape[0] # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively word_level_data = nd.transpose(x[0], axes=(1, 0)) char_level_data = nd.transpose(x[1], axes=(1, 0, 2)) @@ -95,6 +96,11 @@ def forward(self, x, contextual_embedding_state=None): # pylint: disable=argume # Pass through highway, shape remains unchanged highway_output = self._highway_network(highway_input) + # Create starting state if necessary + contextual_embedding_state = \ + self._contextual_embedding.begin_state(batch_size, ctx=highway_output.context) \ + if contextual_embedding_state is None else contextual_embedding_state + # Pass through contextual embedding, which is just bi-LSTM ce_output, ce_state = self._contextual_embedding(highway_output, contextual_embedding_state) @@ -190,7 +196,8 @@ def forward(self, x, m): # pylint: disable=arguments-differ end_index_dense_output = self._end_index_dense(end_index_input) - # TODO: Loss function applies softmax by default, so this code is commented here + # Don't need to apply softmax for training, but do need for prediction + # Maybe should use autograd properties to check it # Will need to reuse it to actually make predictions # start_index_softmax_output = start_index_dense_output.softmax(axis=1) # start_index = nd.argmax(start_index_softmax_output, axis=1) @@ -231,7 +238,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): nlayers=options.output_num_layers, dropout=options.dropout) - def forward(self, x, ctx_embedding_states, q_embedding_states, *args): + def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): ctx_embedding_output, ctx_embedding_state = self._ctx_embedding([x[2], x[4]], ctx_embedding_states) q_embedding_output, q_embedding_state = self._q_embedding([x[1], x[3]], diff --git a/scripts/question_answering/question_id_mapper.py b/scripts/question_answering/question_id_mapper.py new file mode 100644 index 0000000000..5f739b889c --- /dev/null +++ b/scripts/question_answering/question_id_mapper.py @@ -0,0 +1,39 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Question id mapper to and from int""" + + +class QuestionIdMapper: + def __init__(self, dataset): + self._question_id_to_context = {item[1]: item[3] for item in dataset} + self._question_id_to_idx = {item[1]: item[0] for item in dataset} + self._idx_to_question_id = {v: k for k, v in self._question_id_to_idx.items()} + + @property + def question_id_to_context(self): + return self._question_id_to_context + + @property + def idx_to_question_id(self): + return self._idx_to_question_id + + @property + def question_id_to_idx(self): + return self._question_id_to_idx diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 63c298ccf6..b3881bc002 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging +import pickle import argparse import numpy as np @@ -23,7 +25,7 @@ from time import time import mxnet as mx -from mxnet import init, autograd +from mxnet import gluon, init, autograd from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset from mxnet.gluon.loss import SoftmaxCrossEntropyLoss @@ -32,8 +34,9 @@ from gluonnlp.data import SQuAD from scripts.question_answering.data_processing import VocabProvider, SQuADTransform -from scripts.question_answering.metric import f1_score, exact_match_score +from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * +from scripts.question_answering.question_id_mapper import QuestionIdMapper from scripts.question_answering.utils import logging_config np.random.seed(100) @@ -41,32 +44,45 @@ mx.random.seed(10000) -def get_data(is_train, options): - """Get dataset and dataloader +def transform_dataset(dataset, vocab_provider, options): + """Get transformed dataset Parameters ---------- - is_train : `bool` - If `True`, training SQuAD dataset is loaded, if `False` valiidation dataset is loaded + dataset : `Dataset` + Original dataset + vocab_provider : `VocabularyProvider` + Vocabulary provider options : `Namespace` Data transformation arguments Returns ------- data : Tuple - A tuple of dataset and dataloader + A tuple of dataset, QuestionIdMapper and original json data for evaluation """ - dataset = SQuAD(segment='train' if is_train else 'val') - vocab_provider = VocabProvider(dataset) transformer = SQuADTransform(vocab_provider, options.q_max_len, options.ctx_max_len, options.word_max_len) - # TODO: Data processing takes too long for doing experementation - # set it to 256 to speed up thing, but need to refactor this to maybe store processed dataset - # and vocabs. 256 is not a random number, it is 2 * batch_size, so the last batch won't cause - # Invalid recurrent state shape after first batch is finished - processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) - if i < 256]) + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset)]) + return processed_dataset + + +def get_record_per_answer_span(processed_dataset, options): + """Each record has multiple answers and for training purposes it is better to increase number of + records by creating a record per each answer. + Parameters + ---------- + processed_dataset : `Dataset` + Transformed dataset, ready to be trained on + options : `Namespace` + Command arguments + + Returns + ------- + data : Tuple + A tuple of dataset and dataloader + """ data_no_label = [] labels = [] global_index = 0 @@ -85,10 +101,10 @@ def get_data(is_train, options): dataloader = DataLoader(loadable_data, batch_size=options.batch_size, shuffle=True, last_batch='discard') - return dataset, dataloader + return loadable_data, dataloader -def get_vocabs(dataset, options): +def get_vocabs(vocab_provider, options): """Get word-level and character-level vocabularies Parameters @@ -103,8 +119,6 @@ def get_vocabs(dataset, options): data : Tuple A tuple of word vocabulary and character vocabulary """ - vocab_provider = VocabProvider(dataset) - word_vocab = vocab_provider.get_word_level_vocab() word_vocab.set_embedding( @@ -120,20 +134,25 @@ def get_context(options): Parameters ---------- options : `Namespace` - Training arguments + Command arguments """ + ctx = [] + if options.gpu is None: - ctx = mx.cpu() + ctx.append(mx.cpu(0)) print('Use CPU') else: - ctx = mx.gpu(options.gpu) + indices = options.gpu.split(',') + + for index in indices: + ctx.append(mx.gpu(int(index))) return ctx -def run_training(net, dataloader, options): - """Get word-level and character-level vocabularies +def run_training(net, dataloader, evaluator, ctx, options): + """Main function to do training of the network Parameters ---------- @@ -141,38 +160,23 @@ def run_training(net, dataloader, options): Network to train dataloader : `DataLoader` Initialized dataloader + evaluator: `PerformanceEvaluator` + Used to plug in official evaluation script + ctx: `Context` + Training context options : `Namespace` Training arguments - - Returns - ------- - data : Tuple - A tuple of word vocabulary and character vocabulary """ - ctx = get_context(options) trainer = Trainer(net.collect_params(), args.optimizer, {'learning_rate': options.lr}) - eval_metrics = mx.metric.CompositeEvalMetric(metrics=[ - mx.metric.create(lambda label, pred: f1_score(pred, label)), - mx.metric.create(lambda label, pred: exact_match_score(pred, label)) - ]) loss_function = SoftmaxCrossEntropyLoss() - contextual_embedding_param_shape = (4, options.batch_size, options.embedding_size) - ctx_initial_embedding_h0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) - ctx_initial_embedding_c0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) - q_initial_embedding_h0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) - q_initial_embedding_c0 = mx.nd.random.uniform(shape=contextual_embedding_param_shape, ctx=ctx) - - ctx_embedding = [ctx_initial_embedding_h0, ctx_initial_embedding_c0] - q_embedding = [q_initial_embedding_h0, q_initial_embedding_c0] - train_start = time() - avg_loss = mx.nd.zeros((1,), ctx=ctx) + avg_loss = mx.nd.zeros((1,), ctx=ctx[0]) + print("Starting training...") for e in range(args.epochs): avg_loss *= 0 # Zero average loss of each epoch - eval_metrics.reset() # reset metrics before each epoch for i, (data, label) in enumerate(dataloader): # start timing for the first batch of epoch @@ -180,48 +184,80 @@ def run_training(net, dataloader, options): e_start = time() record_index, q_words, ctx_words, q_chars, ctx_chars = data - q_words = q_words.as_in_context(ctx) - ctx_words = ctx_words.as_in_context(ctx) - q_chars = q_chars.as_in_context(ctx) - ctx_chars = ctx_chars.as_in_context(ctx) - label = label.as_in_context(ctx) - - with autograd.record(): - output, ctx_embedding, q_embedding = net((record_index, q_words, ctx_words, q_chars, - ctx_chars), ctx_embedding, q_embedding) - loss = loss_function(output, label) - - loss.backward() + record_index = gluon.utils.split_and_load(record_index, ctx) + q_words = gluon.utils.split_and_load(q_words, ctx) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx) + q_chars = gluon.utils.split_and_load(q_chars, ctx) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx) + label = gluon.utils.split_and_load(label, ctx) + + # Wait for completion of previous iteration to avoid unnecessary memory allocation + mx.nd.waitall() + losses = [] + + for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): + with autograd.record(): + o, _, _ = net((ri, qw, cw, qc, cc)) + loss = loss_function(o, l) + losses.append(loss) + + for l in losses: + l.backward() + trainer.step(options.batch_size) - avg_loss += loss.mean().as_in_context(avg_loss.context) + for l in losses: + avg_loss += l.mean().as_in_context(avg_loss.context) - # TODO: Update eval metrics calculation with actual predictions - # eval_metrics.update(label, output) + eval_results = evaluator.evaluate_performance(net, ctx, options) - # i here would be equal to number of batches - # if multi-GPU, will also need to multiple by GPU qty - avg_loss /= i + avg_loss /= (i * len(ctx)) # block the call here to get correct Time per epoch avg_loss_scalar = avg_loss.asscalar() epoch_time = time() - e_start - # TODO: Fix metrics by using metric.py - original estimator - # metrics = eval_metrics.get() - # Again, in multi-gpu environment multiple i by GPU qty - # avg_metrics = [metric / i for metric in metrics[1]] - # epoch_metrics = (metrics[0], avg_metrics) print("\tEPOCH {:2}: train loss {:4.2f} | batch {:4} | lr {:5.3f} | " - "Time per epoch {:5.2f} seconds" + "Time per epoch {:5.2f} seconds | {}" .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, - epoch_time)) + epoch_time, eval_results)) print("Training time {:6.2f} seconds".format(time() - train_start)) +def save_transformed_dataset(dataset, options): + """Save processed dataset into a file. + + Parameters + ---------- + dataset : `Dataset` + Dataset to save + options : `Namespace` + Saving arguments + """ + pickle.dump(dataset, open(options.preprocessed_dataset_path, "wb")) + + +def load_transformed_dataset(options): + """Loads already preprocessed dataset from disk + + Parameters + ---------- + options : `Namespace` + Loading arguments + """ + processed_dataset = pickle.load(open(options.preprocessed_dataset_path, "rb")) + return processed_dataset + def get_args(): + """Get console arguments + """ parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') + parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') + parser.add_argument('--train', type=bool, default=True, help='Run training') + parser.add_argument('--preprocessed_dataset_path', type=str, + default="preprocessed_dataset.p", help='Path to preprocessed dataset') parser.add_argument('--epochs', type=int, default=40, help='Upper epoch limit') parser.add_argument('--embedding_size', type=int, default=100, help='Dimension of the word embedding') @@ -250,8 +286,8 @@ def get_args(): help='report interval') parser.add_argument('--save_dir', type=str, default='out_dir', help='directory path to save the final model and training log') - parser.add_argument('--gpu', type=int, default=None, - help='id of the gpu to use. Set it to empty means to use cpu.') + parser.add_argument('--gpu', type=str, default=None, + help='Coma-separated ids of the gpu to use. Empty means to use cpu.') args = parser.parse_args() return args @@ -262,10 +298,30 @@ def get_args(): print(args) logging_config(args.save_dir) - train_dataset, train_dataloader = get_data(is_train=True, options=args) - word_vocab, char_vocab = get_vocabs(train_dataset, options=args) - - net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") - net.initialize(init.Xavier(magnitude=2.24)) - - run_training(net, train_dataloader, args) + if args.preprocess: + if not args.preprocessed_dataset_path: + logging.error("Preprocessed_data_path attribute is not provided") + exit(1) + + dataset = SQuAD(segment='train') + vocab_provider = VocabProvider(dataset) + transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) + save_transformed_dataset(transformed_dataset, args) + exit(0) + + if args.train: + dataset = SQuAD(segment='train') + vocab_provider = VocabProvider(dataset) + mapper = QuestionIdMapper(dataset) + transformed_dataset = load_transformed_dataset(args) if args.preprocessed_dataset_path \ + else transform_dataset(dataset, vocab_provider, options=args) + + train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, args) + word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) + ctx = get_context(args) + + evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") + net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) + + run_training(net, train_dataloader, evaluator, ctx, options=args) From 00b051d90d3453985a6fabd714b429bd35f17814 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 31 Aug 2018 17:14:34 -0700 Subject: [PATCH 05/43] Add save params + support for uneven data splits --- .../performance_evaluator.py | 10 +++--- .../train_question_answering.py | 36 +++++++++++++++---- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 9a02323dcc..55a613cd75 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -54,11 +54,11 @@ def evaluate_performance(self, net, ctx, options): for i, data in enumerate(eval_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - record_index = gluon.utils.split_and_load(record_index, ctx) - q_words = gluon.utils.split_and_load(q_words, ctx) - ctx_words = gluon.utils.split_and_load(ctx_words, ctx) - q_chars = gluon.utils.split_and_load(q_chars, ctx) - ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx) + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, q_chars, ctx_chars): out, _, _ = net((ri, qw, cw, qc, cc)) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index b3881bc002..276e64cc43 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os + import logging import pickle @@ -184,12 +186,12 @@ def run_training(net, dataloader, evaluator, ctx, options): e_start = time() record_index, q_words, ctx_words, q_chars, ctx_chars = data - record_index = gluon.utils.split_and_load(record_index, ctx) - q_words = gluon.utils.split_and_load(q_words, ctx) - ctx_words = gluon.utils.split_and_load(ctx_words, ctx) - q_chars = gluon.utils.split_and_load(q_chars, ctx) - ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx) - label = gluon.utils.split_and_load(label, ctx) + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + label = gluon.utils.split_and_load(label, ctx, even_split=False) # Wait for completion of previous iteration to avoid unnecessary memory allocation mx.nd.waitall() @@ -223,9 +225,31 @@ def run_training(net, dataloader, evaluator, ctx, options): .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, epoch_time, eval_results)) + save_model_parameters(e, options) + print("Training time {:6.2f} seconds".format(time() - train_start)) +def save_model_parameters(net, epoch, options): + """Save parameters of the trained model + + Parameters + ---------- + net : `Block` + Model with trained parameters + epoch : `int` + Number of epoch + options : `Namespace` + Saving arguments + """ + + if not os.path.exists(options.save_dir): + os.mkdir(options.save_dir) + + save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch)) + net.save_parameters(save_path) + + def save_transformed_dataset(dataset, options): """Save processed dataset into a file. From 29d8b1e8b70b870ea6f782f52e1560d34d98f967 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 4 Sep 2018 10:18:02 -0700 Subject: [PATCH 06/43] Showcase last batch fails to be processed --- scripts/question_answering/train_question_answering.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 276e64cc43..232d7f8f3f 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -101,7 +101,7 @@ def get_record_per_answer_span(processed_dataset, options): loadable_data = ArrayDataset(data_no_label, labels) dataloader = DataLoader(loadable_data, batch_size=options.batch_size, shuffle=True, - last_batch='discard') + last_batch='keep') return loadable_data, dataloader @@ -246,7 +246,7 @@ def save_model_parameters(net, epoch, options): if not os.path.exists(options.save_dir): os.mkdir(options.save_dir) - save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch)) + save_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(epoch)) net.save_parameters(save_path) @@ -274,6 +274,7 @@ def load_transformed_dataset(options): processed_dataset = pickle.load(open(options.preprocessed_dataset_path, "rb")) return processed_dataset + def get_args(): """Get console arguments """ From 51f7d2a94fdb8bb635f7924cf3ac7cfd01ac9b34 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 4 Sep 2018 13:46:09 -0700 Subject: [PATCH 07/43] Use correct attention layer --- .../question_answering/question_answering.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 673fb8ea2e..4225a57b4b 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -18,9 +18,7 @@ # under the License. """BiDAF model blocks""" -from mxnet.gluon.nn import HybridSequential - -from scripts.question_answering.attention_flow import AttentionFlow +from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.similarity_function import DotProductSimilarity __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] @@ -218,6 +216,7 @@ class BiDAFModel(Block): def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): super().__init__(prefix=prefix, params=params) + self._options = options with self.name_scope(): self._ctx_embedding = BiDAFEmbedding(word_vocab, char_vocab, @@ -230,7 +229,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.highway_num_layers, options.embedding_size, prefix="question_embedding") - self._attention_layer = AttentionFlow(DotProductSimilarity()) + self._attention_layer = BidirectionalAttentionFlow(DotProductSimilarity()) self._modeling_layer = BiDAFModelingLayer(input_dim=options.embedding_size, nlayers=options.modeling_num_layers, dropout=options.dropout) @@ -244,7 +243,17 @@ def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): q_embedding_output, q_embedding_state = self._q_embedding([x[1], x[3]], q_embedding_states) - attention_layer_output = self._attention_layer(ctx_embedding_output, q_embedding_output) + q_mask = x[1] != 0 + ctx_mask = x[2] != 0 + + attention_layer_output = self._attention_layer(ctx_embedding_output, + q_embedding_output, + q_mask, + ctx_mask, + self._options.batch_size, + self._options.ctx_max_len, + self._options.embedding_size) + modeling_layer_output = self._modeling_layer(attention_layer_output) output = self._output_layer(attention_layer_output, modeling_layer_output) From 7cbd563d336a7c7698fe97d15e1add0d107225bb Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 5 Sep 2018 17:29:09 -0700 Subject: [PATCH 08/43] Multiple bug fixes in Attention flow layer --- scripts/question_answering/bidaf.py | 30 ++++++++++--------- .../question_answering/question_answering.py | 7 +++++ scripts/question_answering/utils.py | 18 +++++------ 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index cb2a3df132..33520c15df 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -48,27 +48,29 @@ def hybrid_forward(self, F, encoded_passage, encoded_question, # Shape: (batch_size, passage_length, question_length) passage_question_attention = last_dim_softmax(passage_question_similarity, question_mask) # Shape: (batch_size, passage_length, encoding_dim) - passage_question_vectors = weighted_sum(encoded_question, passage_question_attention) + passage_question_vectors = weighted_sum(F, encoded_question, passage_question_attention) # We replace masked values with something really negative here, so they don't affect the # max below. - masked_similarity = replace_masked_values(passage_question_similarity, - question_mask.expand_dims(1), - -1e7) + masked_similarity = passage_question_similarity if question_mask is None else \ + replace_masked_values(passage_question_similarity, + question_mask.expand_dims(1), + -1e7) # Shape: (batch_size, passage_length) - question_passage_similarity = masked_similarity.max(axis=-1)[0] + question_passage_similarity = masked_similarity.max(axis=-1) # Shape: (batch_size, passage_length) question_passage_attention = masked_softmax(question_passage_similarity, passage_mask) # Shape: (batch_size, encoding_dim) - question_passage_vector = weighted_sum(encoded_passage, question_passage_attention) + question_passage_vector = weighted_sum(F, encoded_passage, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) - tiled_question_passage_vector = question_passage_vector.expand_dims(1).expand(batch_size, - passage_length, - encoding_dim) + tiled_question_passage_vector = question_passage_vector.expand_dims(1) # Shape: (batch_size, passage_length, encoding_dim * 4) - final_merged_passage = F.cat([encoded_passage, - passage_question_vectors, - encoded_passage * passage_question_vectors, - encoded_passage * tiled_question_passage_vector], - dim=-1) + final_merged_passage = F.concat(encoded_passage, + passage_question_vectors, + encoded_passage * passage_question_vectors, + F.broadcast_mul(encoded_passage, + tiled_question_passage_vector), + dim=-1) + + return final_merged_passage diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 4225a57b4b..5d0658673a 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -243,6 +243,11 @@ def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): q_embedding_output, q_embedding_state = self._q_embedding([x[1], x[3]], q_embedding_states) + # attention layer expect batch_size x seq_length x channels + ctx_embedding_output = nd.transpose(ctx_embedding_output, axes=(1, 0, 2)) + q_embedding_output = nd.transpose(q_embedding_output, axes=(1, 0, 2)) + + # Both masks can be None q_mask = x[1] != 0 ctx_mask = x[2] != 0 @@ -254,6 +259,8 @@ def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): self._options.ctx_max_len, self._options.embedding_size) + # modeling layer expects seq_length x batch_size x channels + attention_layer_output = nd.transpose(attention_layer_output, axes=(1, 0, 2)) modeling_layer_output = self._modeling_layer(attention_layer_output) output = self._output_layer(attention_layer_output, modeling_layer_output) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index e9ce838599..eb825ff377 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -216,9 +216,9 @@ def _last_dimension_applicator(function_to_apply, tensor_shape = tensor.shape reshaped_tensor = tensor.reshape(-1, tensor.shape[-1]) if mask is not None: - while mask.shape[0] < tensor.shape[0]: + while len(mask.shape) < len(tensor.shape): mask = mask.expand_dims(1) - mask = mask.broadcast_to(tensor).contiguous().float() + mask = mask.broadcast_to(shape=tensor.shape) mask = mask.reshape(-1, mask.shape[-1]) reshaped_result = function_to_apply(reshaped_tensor, mask) return reshaped_result.reshape(*tensor_shape) @@ -242,7 +242,7 @@ def last_dim_log_softmax(tensor, mask): return _last_dimension_applicator(masked_log_softmax, tensor, mask) -def weighted_sum(matrix, attention): +def weighted_sum(F, matrix, attention): """ Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an "attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical @@ -266,13 +266,13 @@ def weighted_sum(matrix, attention): ``(batch_size, num_documents, num_queries, embedding_dim)`` respectively. """ - if attention.shape[0] == 2 and matrix.shape[0] == 3: - return attention.expand_dims(1).batch_dot(matrix).squeeze(1) - if attention.shape[0] == 3 and matrix.shape[0] == 3: - return attention.batch_dot(matrix) - if matrix.shape[0] - 1 < attention.shape[0]: + if len(attention.shape) == 2 and len(matrix.shape) == 3: + return F.squeeze(F.batch_dot(attention.expand_dims(1), matrix), axis=1) + if len(attention.shape) == 3 and len(matrix.shape) == 3: + return F.batch_dot(attention, matrix) + if len(matrix.shape) - 1 < len(attention.shape): expanded_size = list(matrix.shape) - for i in range(attention.shape[0] - matrix.shape[0] + 1): + for i in range(len(attention.shape) - len(matrix.shape) + 1): matrix = matrix.expand_dims(1) expanded_size.insert(i + 1, attention.shape[i + 1]) matrix = matrix.broadcast_to(*expanded_size) From ca9c948189639939f5d13cb09a549b6b47135091 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 5 Sep 2018 17:44:18 -0700 Subject: [PATCH 09/43] kvstore set to local to prevent malloc exception --- scripts/question_answering/train_question_answering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 232d7f8f3f..d71d311fb6 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -170,7 +170,8 @@ def run_training(net, dataloader, evaluator, ctx, options): Training arguments """ - trainer = Trainer(net.collect_params(), args.optimizer, {'learning_rate': options.lr}) + trainer = Trainer(net.collect_params(), args.optimizer, + {'learning_rate': options.lr}, kvstore="local") loss_function = SoftmaxCrossEntropyLoss() train_start = time() From 1f39fb12c5b4082a17d9f18304d5c6612cf15488 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 11 Sep 2018 15:57:09 -0700 Subject: [PATCH 10/43] Fix to get 1 epoch on 4 gpu ~1.8 hours --- scripts/question_answering/bidaf.py | 2 +- .../performance_evaluator.py | 9 +++- .../question_answering/question_answering.py | 43 +++++++++++++------ .../train_question_answering.py | 31 ++++++++++--- 4 files changed, 62 insertions(+), 23 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index 33520c15df..a3c4163104 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -10,7 +10,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, +# Unless required by applicable law or agreed to in writinConvolutionalEncoderg, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 55a613cd75..c3d0bed6bc 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -54,6 +54,13 @@ def evaluate_performance(self, net, ctx, options): for i, data in enumerate(eval_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data + + record_index = record_index.astype(options.precision) + q_words = q_words.astype(options.precision) + ctx_words = ctx_words.astype(options.precision) + q_chars = q_chars.astype(options.precision) + ctx_chars = ctx_chars.astype(options.precision) + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) @@ -95,7 +102,7 @@ def get_text_result(self, idx, answer_span): start, end = answer_span if start > end: - return None + return '' question_id = self._mapper.idx_to_question_id[idx] context = self._mapper.question_id_to_context[question_id] diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 5d0658673a..df33a672d1 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -38,9 +38,10 @@ class BiDAFEmbedding(Block): 2. Tensor of characters: batch_size x words_per_question/context x chars_per_word """ def __init__(self, word_vocab, char_vocab, contextual_embedding_nlayers=2, highway_nlayers=2, - embedding_size=100, prefix=None, params=None): + embedding_size=100, precision='float32', prefix=None, params=None): super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) + self._precision = precision self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), output_dim=8) self._char_conv_embedding = ConvolutionalEncoder( embed_size=8, @@ -79,9 +80,8 @@ def forward(self, x, contextual_embedding_state=None): # pylint: disable=argume # Step 3. Iterate over tokens of each batch and apply convolutional encoder # As a result of a single iteration, we get token embedding for every batch - token_list = [] - for token_of_all_batches in char_level_data: - token_list.append(self._char_conv_embedding(token_of_all_batches)) + token_list = [self._char_conv_embedding(token_of_all_batches) + for token_of_all_batches in char_level_data] # Step 4. Concat all tokens embeddings to create a single tensor. char_embedded = nd.concat(*token_list, dim=0) @@ -96,7 +96,9 @@ def forward(self, x, contextual_embedding_state=None): # pylint: disable=argume # Create starting state if necessary contextual_embedding_state = \ - self._contextual_embedding.begin_state(batch_size, ctx=highway_output.context) \ + self._contextual_embedding.begin_state(batch_size, + ctx=highway_output.context, + dtype=self._precision) \ if contextual_embedding_state is None else contextual_embedding_state # Pass through contextual embedding, which is just bi-LSTM @@ -131,14 +133,18 @@ class BiDAFModelingLayer(Block): Shared Parameters for this `Block`. """ def __init__(self, input_dim=100, nlayers=2, biflag=True, - dropout=0.2, prefix=None, params=None): + dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) + self._precision = precision self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag) def forward(self, x): # pylint: disable=arguments-differ - out = self._modeling_layer(x) + batch_size = x.shape[1] + + state = self._modeling_layer.begin_state(batch_size, ctx=x.context, dtype=self._precision) + out, _ = self._modeling_layer(x, state) return out @@ -158,7 +164,7 @@ class BiDAFOutputLayer(Block): ---------- span_start_input_dim : `int`, default 100 The number of features in the hidden state h of LSTM - units : `int`, default 10 * ``span_start_input_dim`` + units : `int`, default 4 * ``span_start_input_dim`` Number of hidden units of `Dense` layer nlayers : `int`, default 1 Number of recurrent layers. @@ -173,22 +179,26 @@ class BiDAFOutputLayer(Block): Shared Parameters for this `Block`. """ def __init__(self, span_start_input_dim=100, units=None, nlayers=1, biflag=True, - dropout=0.2, prefix=None, params=None): + dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) - units = 10 * span_start_input_dim if units is None else units + units = 4 * span_start_input_dim if units is None else units + self._precision = precision self._start_index_dense = nn.Dense(units=units) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag) self._end_index_dense = nn.Dense(units=units) def forward(self, x, m): # pylint: disable=arguments-differ + batch_size = x.shape[1] + # setting batch size as the first dimension start_index_input = nd.transpose(nd.concat(x, m, dim=2), axes=(1, 0, 2)) start_index_dense_output = self._start_index_dense(start_index_input) - end_index_input_part = self._end_index_lstm(m) + state = self._end_index_lstm.begin_state(batch_size, ctx=x.context, dtype=self._precision) + end_index_input_part, _ = self._end_index_lstm(m, state) end_index_input = nd.transpose(nd.concat(x, end_index_input_part, dim=2), axes=(1, 0, 2)) @@ -223,19 +233,23 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.ctx_embedding_num_layers, options.highway_num_layers, options.embedding_size, + precision=options.precision, prefix="context_embedding") self._q_embedding = BiDAFEmbedding(word_vocab, char_vocab, options.ctx_embedding_num_layers, options.highway_num_layers, options.embedding_size, + precision=options.precision, prefix="question_embedding") self._attention_layer = BidirectionalAttentionFlow(DotProductSimilarity()) self._modeling_layer = BiDAFModelingLayer(input_dim=options.embedding_size, nlayers=options.modeling_num_layers, - dropout=options.dropout) + dropout=options.dropout, + precision=options.precision) self._output_layer = BiDAFOutputLayer(span_start_input_dim=options.embedding_size, nlayers=options.output_num_layers, - dropout=options.dropout) + dropout=options.dropout, + precision=options.precision) def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): ctx_embedding_output, ctx_embedding_state = self._ctx_embedding([x[2], x[4]], @@ -258,10 +272,11 @@ def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): self._options.batch_size, self._options.ctx_max_len, self._options.embedding_size) - # modeling layer expects seq_length x batch_size x channels attention_layer_output = nd.transpose(attention_layer_output, axes=(1, 0, 2)) + modeling_layer_output = self._modeling_layer(attention_layer_output) + output = self._output_layer(attention_layer_output, modeling_layer_output) return output, ctx_embedding_state, q_embedding_state diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index d71d311fb6..e9a4d91f90 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import multiprocessing import os import logging @@ -171,11 +172,12 @@ def run_training(net, dataloader, evaluator, ctx, options): """ trainer = Trainer(net.collect_params(), args.optimizer, - {'learning_rate': options.lr}, kvstore="local") + {'learning_rate': options.lr}, + kvstore="device") loss_function = SoftmaxCrossEntropyLoss() train_start = time() - avg_loss = mx.nd.zeros((1,), ctx=ctx[0]) + avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) print("Starting training...") for e in range(args.epochs): @@ -187,6 +189,14 @@ def run_training(net, dataloader, evaluator, ctx, options): e_start = time() record_index, q_words, ctx_words, q_chars, ctx_chars = data + + record_index = record_index.astype(options.precision) + q_words = q_words.astype(options.precision) + ctx_words = ctx_words.astype(options.precision) + q_chars = q_chars.astype(options.precision) + ctx_chars = ctx_chars.astype(options.precision) + label = label.astype(options.precision) + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) @@ -213,7 +223,11 @@ def run_training(net, dataloader, evaluator, ctx, options): for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) - eval_results = evaluator.evaluate_performance(net, ctx, options) + mx.nd.waitall() + print("Start evaluate performance") + #eval_results = evaluator.evaluate_performance(net, ctx, options) + eval_results = {} + print("End evaluate performance") avg_loss /= (i * len(ctx)) @@ -226,7 +240,7 @@ def run_training(net, dataloader, evaluator, ctx, options): .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, epoch_time, eval_results)) - save_model_parameters(e, options) + save_model_parameters(net, e, options) print("Training time {:6.2f} seconds".format(time() - train_start)) @@ -299,9 +313,7 @@ def get_args(): help='Number of layers in Output layer of BiDAF') parser.add_argument('--batch_size', type=int, default=128, help='Batch size') parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') - # TODO: Question max length in the paper is 30. Had to set it 400 to make dot_product - # similarity work - parser.add_argument('--q_max_len', type=int, default=400, help='Maximum length of a question') + parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') parser.add_argument('--lr', type=float, default=1E-3, help='Initial learning rate') @@ -314,6 +326,10 @@ def get_args(): help='directory path to save the final model and training log') parser.add_argument('--gpu', type=str, default=None, help='Coma-separated ids of the gpu to use. Empty means to use cpu.') + parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], + help='Use float16 or float32 precision') + #parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, + # help='When using float16, shall optimizer use multiprecision.') args = parser.parse_args() return args @@ -349,5 +365,6 @@ def get_args(): evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) + net.cast(args.precision) run_training(net, train_dataloader, evaluator, ctx, options=args) From 8b14acd310eb029059fbe9001304c84351a51cef Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 18 Sep 2018 15:12:28 -0700 Subject: [PATCH 11/43] Make evaluation faster --- .../performance_evaluator.py | 9 +++ .../train_question_answering.py | 72 +++++++++++++++---- 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index c3d0bed6bc..944d0a810b 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -67,8 +67,13 @@ def evaluate_performance(self, net, ctx, options): q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + outs = [] + for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, q_chars, ctx_chars): out, _, _ = net((ri, qw, cw, qc, cc)) + outs.append(out) + + for out in outs: out_per_index = out.transpose(axes=(1, 0, 2)) start_indices = PerformanceEvaluator._get_index(out_per_index[0]) end_indices = PerformanceEvaluator._get_index(out_per_index[1]) @@ -80,6 +85,10 @@ def evaluate_performance(self, net, ctx, options): end = int(end.asscalar()) pred[self._mapper.idx_to_question_id[idx]] = self.get_text_result(idx, (start, end)) + if options.save_prediction_path: + with open(options.save_prediction_path, "w") as f: + for item in pred.items(): + f.write("QId {}, Answer: {}\n".format(item[0], item[1])) return evaluate(self._json_data['data'], pred) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index e9a4d91f90..8636fa21b5 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -18,6 +18,7 @@ # under the License. import multiprocessing import os +from os.path import isfile import logging import pickle @@ -265,28 +266,28 @@ def save_model_parameters(net, epoch, options): net.save_parameters(save_path) -def save_transformed_dataset(dataset, options): +def save_transformed_dataset(dataset, path): """Save processed dataset into a file. Parameters ---------- dataset : `Dataset` Dataset to save - options : `Namespace` - Saving arguments + path : `str` + Saving path """ - pickle.dump(dataset, open(options.preprocessed_dataset_path, "wb")) + pickle.dump(dataset, open(path, "wb")) -def load_transformed_dataset(options): +def load_transformed_dataset(path): """Loads already preprocessed dataset from disk Parameters ---------- - options : `Namespace` - Loading arguments + path : `str` + Loading path """ - processed_dataset = pickle.load(open(options.preprocessed_dataset_path, "rb")) + processed_dataset = pickle.load(open(path, "rb")) return processed_dataset @@ -295,9 +296,13 @@ def get_args(): """ parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') - parser.add_argument('--train', type=bool, default=True, help='Run training') + parser.add_argument('--train', type=bool, default=False, help='Run training') + parser.add_argument('--evaluate', type=bool, default=False, help='Run evaluation on dev dataset') parser.add_argument('--preprocessed_dataset_path', type=str, default="preprocessed_dataset.p", help='Path to preprocessed dataset') + parser.add_argument('--preprocessed_val_dataset_path', type=str, + default="preprocessed_val_dataset.p", help='Path to preprocessed ' + 'validation dataset') parser.add_argument('--epochs', type=int, default=40, help='Upper epoch limit') parser.add_argument('--embedding_size', type=int, default=100, help='Dimension of the word embedding') @@ -328,6 +333,8 @@ def get_args(): help='Coma-separated ids of the gpu to use. Empty means to use cpu.') parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], help='Use float16 or float32 precision') + parser.add_argument('--save_prediction_path', type=str, default='', + help='Path to save predictions') #parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, # help='When using float16, shall optimizer use multiprecision.') @@ -345,18 +352,26 @@ def get_args(): logging.error("Preprocessed_data_path attribute is not provided") exit(1) + print("Running in preprocessing mode") + dataset = SQuAD(segment='train') vocab_provider = VocabProvider(dataset) transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) - save_transformed_dataset(transformed_dataset, args) + save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) exit(0) if args.train: + print("Running in training mode") + dataset = SQuAD(segment='train') vocab_provider = VocabProvider(dataset) mapper = QuestionIdMapper(dataset) - transformed_dataset = load_transformed_dataset(args) if args.preprocessed_dataset_path \ - else transform_dataset(dataset, vocab_provider, options=args) + + if args.preprocessed_dataset_path and isfile(args.preprocessed_dataset_path): + transformed_dataset = load_transformed_dataset(args.preprocessed_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) + save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, args) word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) @@ -368,3 +383,36 @@ def get_args(): net.cast(args.precision) run_training(net, train_dataloader, evaluator, ctx, options=args) + + if args.evaluate: + print("Running in evaluation mode") + # we use training dataset to build vocabs + model_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(int(args.epochs) - 1)) + + train_dataset = SQuAD(segment='train') + vocab_provider = VocabProvider(train_dataset) + + dataset = SQuAD(segment='dev') + mapper = QuestionIdMapper(dataset) + + transformed_dataset = load_transformed_dataset(args.preprocessed_val_dataset_path) \ + if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path) \ + else transform_dataset(dataset, vocab_provider, options=args) + + if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path): + transformed_dataset = load_transformed_dataset(args.preprocessed_val_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) + save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) + + val_dataset, val_dataloader = get_record_per_answer_span(transformed_dataset, args) + word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) + ctx = get_context(args) + + evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") + net.load_parameters(model_path, ctx=ctx) + + result = evaluator.evaluate_performance(net, ctx, args) + print("Evaluation results on dev dataset: {}".format(result)) + From f8368f6250e3d0bc6b22b0612c3df2f502dddc36 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 20 Sep 2018 13:45:32 -0700 Subject: [PATCH 12/43] Update hyperparameters --- scripts/question_answering/train_question_answering.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 8636fa21b5..2856e3b361 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -303,7 +303,7 @@ def get_args(): parser.add_argument('--preprocessed_val_dataset_path', type=str, default="preprocessed_val_dataset.p", help='Path to preprocessed ' 'validation dataset') - parser.add_argument('--epochs', type=int, default=40, help='Upper epoch limit') + parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') parser.add_argument('--embedding_size', type=int, default=100, help='Dimension of the word embedding') parser.add_argument('--dropout', type=float, default=0.2, @@ -320,10 +320,8 @@ def get_args(): parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') - parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') - parser.add_argument('--lr', type=float, default=1E-3, help='Initial learning rate') - parser.add_argument('--lr_update_factor', type=float, default=0.5, - help='Learning rate decay factor') + parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') + parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') From 152bb4c95244dacfa6d45091df0b936a09930f74 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Mon, 24 Sep 2018 14:09:35 -0700 Subject: [PATCH 13/43] Float16 works, hybridization not --- scripts/question_answering/attention_flow.py | 23 +- scripts/question_answering/bidaf.py | 47 +++- .../question_answering/question_answering.py | 225 ++++++++++-------- .../question_answering/similarity_function.py | 2 +- .../train_question_answering.py | 58 +---- scripts/question_answering/utils.py | 107 +++++++-- scripts/tests/test_question_answering.py | 184 ++++++++++++-- 7 files changed, 438 insertions(+), 208 deletions(-) diff --git a/scripts/question_answering/attention_flow.py b/scripts/question_answering/attention_flow.py index 5154704772..4d55a5655c 100644 --- a/scripts/question_answering/attention_flow.py +++ b/scripts/question_answering/attention_flow.py @@ -46,19 +46,24 @@ class AttentionFlow(gluon.HybridBlock): similarity_function: ``SimilarityFunction``, optional (default=``DotProductSimilarity``) The similarity function to use when computing the attention. """ - def __init__(self, similarity_function, **kwargs): + def __init__(self, similarity_function, batch_size, passage_length, + question_length, embedding_size, **kwargs): super(AttentionFlow, self).__init__(**kwargs) self._similarity_function = similarity_function or DotProductSimilarity() + self._batch_size = batch_size + self._passage_length = passage_length + self._question_length = question_length + self._embedding_size = embedding_size def hybrid_forward(self, F, matrix_1, matrix_2): # pylint: disable=arguments-differ - tiled_matrix_1 = matrix_1.expand_dims(2).broadcast_to(shape=(matrix_1.shape[0], - matrix_1.shape[1], - matrix_2.shape[1], - matrix_1.shape[2])) - tiled_matrix_2 = matrix_2.expand_dims(1).broadcast_to(shape=(matrix_2.shape[0], - matrix_1.shape[1], - matrix_2.shape[1], - matrix_2.shape[2])) + tiled_matrix_1 = matrix_1.expand_dims(2).broadcast_to(shape=(self._batch_size, + self._passage_length, + self._question_length, + self._embedding_size)) + tiled_matrix_2 = matrix_2.expand_dims(1).broadcast_to(shape=(self._batch_size, + self._passage_length, + self._question_length, + self._embedding_size)) return self._similarity_function(tiled_matrix_1, tiled_matrix_2) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index a3c4163104..0f32e1d165 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -32,36 +32,67 @@ class BidirectionalAttentionFlow(gluon.HybridBlock): """ def __init__(self, attention_similarity_function, + batch_size, + passage_length, + question_length, + encoding_dim, **kwargs): super(BidirectionalAttentionFlow, self).__init__(**kwargs) - self._matrix_attention = AttentionFlow(attention_similarity_function) + self._batch_size = batch_size + self._passage_length = passage_length + self._question_length = question_length + self._encoding_dim = encoding_dim + self._matrix_attention = AttentionFlow(attention_similarity_function, + batch_size, passage_length, question_length, + encoding_dim) - def hybrid_forward(self, F, encoded_passage, encoded_question, - question_mask, passage_mask, batch_size, passage_length, encoding_dim): + def hybrid_forward(self, F, encoded_passage, encoded_question, question_mask, passage_mask): # pylint: disable=arguments-differ """ """ # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) + passage_question_similarity_shape = (self._batch_size, self._passage_length, + self._question_length) + + question_mask_shape = (self._batch_size, self._question_length) # Shape: (batch_size, passage_length, question_length) - passage_question_attention = last_dim_softmax(passage_question_similarity, question_mask) + passage_question_attention = last_dim_softmax(F, + passage_question_similarity, + question_mask, + passage_question_similarity_shape, + question_mask_shape) # Shape: (batch_size, passage_length, encoding_dim) - passage_question_vectors = weighted_sum(F, encoded_question, passage_question_attention) + encoded_question_shape = (self._batch_size, self._question_length, self._encoding_dim) + passage_question_attention_shape = (self._batch_size, self._passage_length, + self._question_length) + passage_question_vectors = weighted_sum(F, encoded_question, passage_question_attention, + encoded_question_shape, + passage_question_attention_shape) # We replace masked values with something really negative here, so they don't affect the # max below. masked_similarity = passage_question_similarity if question_mask is None else \ - replace_masked_values(passage_question_similarity, + replace_masked_values(F, + passage_question_similarity, question_mask.expand_dims(1), -1e7) + # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(axis=-1) + # Shape: (batch_size, passage_length) - question_passage_attention = masked_softmax(question_passage_similarity, passage_mask) + question_passage_attention = masked_softmax(F, question_passage_similarity, passage_mask) + # Shape: (batch_size, encoding_dim) - question_passage_vector = weighted_sum(F, encoded_passage, question_passage_attention) + encoded_passage_shape = (self._batch_size, self._passage_length, self._encoding_dim) + question_passage_attention_shape = (self._batch_size, self._passage_length) + question_passage_vector = weighted_sum(F, encoded_passage, question_passage_attention, + encoded_passage_shape, + question_passage_attention_shape) + # Shape: (batch_size, passage_length, encoding_dim) tiled_question_passage_vector = question_passage_vector.expand_dims(1) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index df33a672d1..dd61472dfa 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -23,25 +23,30 @@ __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] -from mxnet import nd, init -from mxnet.gluon import Block +from mxnet import initializer +from mxnet.gluon import HybridBlock from mxnet.gluon import nn from mxnet.gluon.rnn import LSTM from gluonnlp.model import ConvolutionalEncoder, Highway -class BiDAFEmbedding(Block): +class BiDAFEmbedding(HybridBlock): """BiDAFEmbedding is a class describing embeddings that are separately applied to question and context of the datasource. Both question and context are passed in two NDArrays: 1. Matrix of words: batch_size x words_per_question/context 2. Tensor of characters: batch_size x words_per_question/context x chars_per_word """ - def __init__(self, word_vocab, char_vocab, contextual_embedding_nlayers=2, highway_nlayers=2, - embedding_size=100, precision='float32', prefix=None, params=None): + def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, + contextual_embedding_nlayers=2, highway_nlayers=2, embedding_size=100, + precision='float32', prefix=None, params=None): super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) + self._word_vocab = word_vocab + self._batch_size = batch_size + self._max_seq_len = max_seq_len self._precision = precision + self._embedding_size = embedding_size self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), output_dim=8) self._char_conv_embedding = ConvolutionalEncoder( embed_size=8, @@ -52,20 +57,27 @@ def __init__(self, word_vocab, char_vocab, contextual_embedding_nlayers=2, highw output_size=None ) - self._word_embedding = nn.Embedding(input_dim=len(word_vocab), output_dim=embedding_size, - weight_initializer=init.Constant( - word_vocab.embedding.idx_to_vec)) + self._word_embedding = nn.Embedding(input_dim=len(word_vocab), output_dim=embedding_size) self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) self._contextual_embedding = LSTM(hidden_size=embedding_size, num_layers=contextual_embedding_nlayers, - bidirectional=True) + bidirectional=True, input_size=2 * embedding_size) - def forward(self, x, contextual_embedding_state=None): # pylint: disable=arguments-differ - batch_size = x[0].shape[0] + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): + super(BiDAFEmbedding, self).initialize(init, ctx, verbose, force_reinit) + self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) + + def begin_state(self): + state = self._contextual_embedding.begin_state(self._batch_size, + dtype=self._precision) + + return state + + def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively - word_level_data = nd.transpose(x[0], axes=(1, 0)) - char_level_data = nd.transpose(x[1], axes=(1, 0, 2)) + word_level_data = F.transpose(w, axes=(1, 0)) + char_level_data = F.transpose(c, axes=(1, 0, 2)) # Get word embeddings. Output is batch_size x seq_len x embedding size (100) word_embedded = self._word_embedding(word_level_data) @@ -76,39 +88,38 @@ def forward(self, x, contextual_embedding_state=None): # pylint: disable=argume # Step 2. Transpose to put seq_len first axis to later iterate over it # In that way we can get embedding per token of every batch - char_level_data = nd.transpose(char_level_data, axes=(0, 2, 1, 3)) + char_level_data = F.transpose(char_level_data, axes=(0, 2, 1, 3)) # Step 3. Iterate over tokens of each batch and apply convolutional encoder # As a result of a single iteration, we get token embedding for every batch - token_list = [self._char_conv_embedding(token_of_all_batches) - for token_of_all_batches in char_level_data] + def convolute(token_of_all_batches, _): + return self._char_conv_embedding(token_of_all_batches), [] + + token_list, _ = F.contrib.foreach(convolute, char_level_data, []) # Step 4. Concat all tokens embeddings to create a single tensor. - char_embedded = nd.concat(*token_list, dim=0) + char_embedded = F.concat(*token_list, dim=0) # Step 5. Reshape tensor to match dimensions of embedded words - char_embedded = char_embedded.reshape(shape=word_embedded.shape) + char_embedded = char_embedded.reshape(shape=(self._max_seq_len, + self._batch_size, + self._embedding_size)) # Concat embeddings, making channels size = 200 - highway_input = nd.concat(char_embedded, word_embedded, dim=2) + highway_input = F.concat(char_embedded, word_embedded, dim=2) + # Pass through highway, shape remains unchanged highway_output = self._highway_network(highway_input) - # Create starting state if necessary - contextual_embedding_state = \ - self._contextual_embedding.begin_state(batch_size, - ctx=highway_output.context, - dtype=self._precision) \ - if contextual_embedding_state is None else contextual_embedding_state + if contextual_embedding_state is None: + contextual_embedding_state = self.begin_state() - # Pass through contextual embedding, which is just bi-LSTM ce_output, ce_state = self._contextual_embedding(highway_output, contextual_embedding_state) - - return ce_output, ce_state + return ce_output -class BiDAFModelingLayer(Block): +class BiDAFModelingLayer(HybridBlock): """BiDAFModelingLayer implements modeling layer of BiDAF paper. It is used to scan over context produced by Attentional Flow Layer via 2 layer bi-LSTM. @@ -132,23 +143,29 @@ class BiDAFModelingLayer(Block): params : `ParameterDict` or `None` Shared Parameters for this `Block`. """ - def __init__(self, input_dim=100, nlayers=2, biflag=True, + def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) + self._batch_size = batch_size self._precision = precision self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, - bidirectional=biflag) + bidirectional=biflag, input_size=800) + + def begin_state(self): + state = self._modeling_layer.begin_state(self._batch_size, + dtype=self._precision) + return state - def forward(self, x): # pylint: disable=arguments-differ - batch_size = x.shape[1] + def hybrid_forward(self, F, x, state, *args): + if state is None: + state = self.begin_state() - state = self._modeling_layer.begin_state(batch_size, ctx=x.context, dtype=self._precision) out, _ = self._modeling_layer(x, state) return out -class BiDAFOutputLayer(Block): +class BiDAFOutputLayer(HybridBlock): """ ``BiDAFOutputLayer`` produces the final prediction of an answer. The output is a tuple of start index and end index of the answer in the paragraph per each batch. @@ -178,28 +195,36 @@ class BiDAFOutputLayer(Block): params : `ParameterDict` or `None` Shared Parameters for this `Block`. """ - def __init__(self, span_start_input_dim=100, units=None, nlayers=1, biflag=True, + def __init__(self, batch_size, span_start_input_dim=100, units=None, nlayers=1, biflag=True, dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) units = 4 * span_start_input_dim if units is None else units + self._batch_size = batch_size self._precision = precision self._start_index_dense = nn.Dense(units=units) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, - num_layers=nlayers, dropout=dropout, bidirectional=biflag) + num_layers=nlayers, dropout=dropout, bidirectional=biflag, + input_size=200) self._end_index_dense = nn.Dense(units=units) - def forward(self, x, m): # pylint: disable=arguments-differ - batch_size = x.shape[1] + def begin_state(self): + state = self._end_index_lstm.begin_state(self._batch_size, + dtype=self._precision) + return state + + def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-differ # setting batch size as the first dimension - start_index_input = nd.transpose(nd.concat(x, m, dim=2), axes=(1, 0, 2)) + start_index_input = F.transpose(F.concat(x, m, dim=2), axes=(1, 0, 2)) start_index_dense_output = self._start_index_dense(start_index_input) - state = self._end_index_lstm.begin_state(batch_size, ctx=x.context, dtype=self._precision) + if state is None: + state = self.begin_state() + end_index_input_part, _ = self._end_index_lstm(m, state) - end_index_input = nd.transpose(nd.concat(x, end_index_input_part, dim=2), + end_index_input = F.transpose(F.concat(x, end_index_input_part, dim=2), axes=(1, 0, 2)) end_index_dense_output = self._end_index_dense(end_index_input) @@ -208,75 +233,89 @@ def forward(self, x, m): # pylint: disable=arguments-differ # Maybe should use autograd properties to check it # Will need to reuse it to actually make predictions # start_index_softmax_output = start_index_dense_output.softmax(axis=1) - # start_index = nd.argmax(start_index_softmax_output, axis=1) + # start_index = F.argmax(start_index_softmax_output, axis=1) # end_index_softmax_output = end_index_dense_output.softmax(axis=1) - # end_index = nd.argmax(end_index_softmax_output, axis=1) + # end_index = F.argmax(end_index_softmax_output, axis=1) # producing output in shape 2 x batch_size x units - output = nd.concat(nd.expand_dims(start_index_dense_output, axis=0), - nd.expand_dims(end_index_dense_output, axis=0), dim=0) + output = F.concat(F.expand_dims(start_index_dense_output, axis=0), + F.expand_dims(end_index_dense_output, axis=0), dim=0) # transposing it to batch_size x 2 x units - return nd.transpose(output, axes=(1, 0, 2)) + return F.transpose(output, axes=(1, 0, 2)) -class BiDAFModel(Block): +class BiDAFModel(HybridBlock): """Bidirectional attention flow model for Question answering """ - def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): super().__init__(prefix=prefix, params=params) self._options = options with self.name_scope(): - self._ctx_embedding = BiDAFEmbedding(word_vocab, char_vocab, - options.ctx_embedding_num_layers, - options.highway_num_layers, - options.embedding_size, - precision=options.precision, - prefix="context_embedding") - self._q_embedding = BiDAFEmbedding(word_vocab, char_vocab, - options.ctx_embedding_num_layers, - options.highway_num_layers, - options.embedding_size, - precision=options.precision, - prefix="question_embedding") - self._attention_layer = BidirectionalAttentionFlow(DotProductSimilarity()) - self._modeling_layer = BiDAFModelingLayer(input_dim=options.embedding_size, - nlayers=options.modeling_num_layers, - dropout=options.dropout, - precision=options.precision) - self._output_layer = BiDAFOutputLayer(span_start_input_dim=options.embedding_size, - nlayers=options.output_num_layers, - dropout=options.dropout, - precision=options.precision) - - def forward(self, x, ctx_embedding_states=None, q_embedding_states=None, *args): - ctx_embedding_output, ctx_embedding_state = self._ctx_embedding([x[2], x[4]], - ctx_embedding_states) - q_embedding_output, q_embedding_state = self._q_embedding([x[1], x[3]], - q_embedding_states) + self.ctx_embedding = BiDAFEmbedding(options.batch_size, + word_vocab, + char_vocab, + options.ctx_max_len, + options.ctx_embedding_num_layers, + options.highway_num_layers, + options.embedding_size, + precision=options.precision, + prefix="context_embedding") + self.q_embedding = BiDAFEmbedding(options.batch_size, + word_vocab, + char_vocab, + options.q_max_len, + options.ctx_embedding_num_layers, + options.highway_num_layers, + options.embedding_size, + precision=options.precision, + prefix="question_embedding") + + # we multiple embedding_size by 2 because we use bidirectional embedding + self.attention_layer = BidirectionalAttentionFlow(DotProductSimilarity(), + options.batch_size, + options.ctx_max_len, + options.q_max_len, + 2 * options.embedding_size) + self.modeling_layer = BiDAFModelingLayer(options.batch_size, + input_dim=options.embedding_size, + nlayers=options.modeling_num_layers, + dropout=options.dropout, + precision=options.precision) + self.output_layer = BiDAFOutputLayer(options.batch_size, + span_start_input_dim=options.embedding_size, + nlayers=options.output_num_layers, + dropout=options.dropout, + precision=options.precision) + + def hybrid_forward(self, F, ri, qw, cw, qc, cc, + ctx_embedding_states=None, + q_embedding_states=None, + modeling_layer_states=None, + output_layer_states=None, + *args): + ctx_embedding_output = self.ctx_embedding(cw, cc, ctx_embedding_states) + q_embedding_output = self.q_embedding(qw, qc, q_embedding_states) # attention layer expect batch_size x seq_length x channels - ctx_embedding_output = nd.transpose(ctx_embedding_output, axes=(1, 0, 2)) - q_embedding_output = nd.transpose(q_embedding_output, axes=(1, 0, 2)) + ctx_embedding_output = F.transpose(ctx_embedding_output, axes=(1, 0, 2)) + q_embedding_output = F.transpose(q_embedding_output, axes=(1, 0, 2)) # Both masks can be None - q_mask = x[1] != 0 - ctx_mask = x[2] != 0 - - attention_layer_output = self._attention_layer(ctx_embedding_output, - q_embedding_output, - q_mask, - ctx_mask, - self._options.batch_size, - self._options.ctx_max_len, - self._options.embedding_size) - # modeling layer expects seq_length x batch_size x channels - attention_layer_output = nd.transpose(attention_layer_output, axes=(1, 0, 2)) + q_mask = qw != 0 + ctx_mask = cw != 0 + + attention_layer_output = self.attention_layer(ctx_embedding_output, + q_embedding_output, + q_mask, + ctx_mask) + attention_layer_output = F.transpose(attention_layer_output, axes=(1, 0, 2)) - modeling_layer_output = self._modeling_layer(attention_layer_output) + # modeling layer expects seq_length x batch_size x channels + modeling_layer_output = self.modeling_layer(attention_layer_output, modeling_layer_states) - output = self._output_layer(attention_layer_output, modeling_layer_output) + output = self.output_layer(attention_layer_output, modeling_layer_output, + output_layer_states) - return output, ctx_embedding_state, q_embedding_state + return output diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index 0b52e24bf7..26616a6729 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -66,7 +66,7 @@ def hybrid_forward(self, F, array_1, array_2): result = (array_1 * array_2).sum(axis=-1) if self._scale_output: - result *= F.sqrt(array_1.shape[-1]) + result *= F.contrib.div_sqrt_dim(array_1) return result diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 2856e3b361..ce99b05f5e 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -41,7 +41,7 @@ from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper -from scripts.question_answering.utils import logging_config +from scripts.question_answering.utils import logging_config, get_args np.random.seed(100) random.seed(100) @@ -212,7 +212,7 @@ def run_training(net, dataloader, evaluator, ctx, options): for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, q_chars, ctx_chars, label): with autograd.record(): - o, _, _ = net((ri, qw, cw, qc, cc)) + o, _, _ = net(ri, qw, cw, qc, cc) loss = loss_function(o, l) losses.append(loss) @@ -291,55 +291,6 @@ def load_transformed_dataset(path): return processed_dataset -def get_args(): - """Get console arguments - """ - parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') - parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') - parser.add_argument('--train', type=bool, default=False, help='Run training') - parser.add_argument('--evaluate', type=bool, default=False, help='Run evaluation on dev dataset') - parser.add_argument('--preprocessed_dataset_path', type=str, - default="preprocessed_dataset.p", help='Path to preprocessed dataset') - parser.add_argument('--preprocessed_val_dataset_path', type=str, - default="preprocessed_val_dataset.p", help='Path to preprocessed ' - 'validation dataset') - parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') - parser.add_argument('--embedding_size', type=int, default=100, - help='Dimension of the word embedding') - parser.add_argument('--dropout', type=float, default=0.2, - help='dropout applied to layers (0 = no dropout)') - parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, - help='Number of layers in Contextual embedding layer of BiDAF') - parser.add_argument('--highway_num_layers', type=int, default=2, - help='Number of layers in Highway layer of BiDAF') - parser.add_argument('--modeling_num_layers', type=int, default=2, - help='Number of layers in Modeling layer of BiDAF') - parser.add_argument('--output_num_layers', type=int, default=1, - help='Number of layers in Output layer of BiDAF') - parser.add_argument('--batch_size', type=int, default=128, help='Batch size') - parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') - parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') - parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') - parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') - parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') - parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') - parser.add_argument('--log_interval', type=int, default=100, metavar='N', - help='report interval') - parser.add_argument('--save_dir', type=str, default='out_dir', - help='directory path to save the final model and training log') - parser.add_argument('--gpu', type=str, default=None, - help='Coma-separated ids of the gpu to use. Empty means to use cpu.') - parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], - help='Use float16 or float32 precision') - parser.add_argument('--save_prediction_path', type=str, default='', - help='Path to save predictions') - #parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, - # help='When using float16, shall optimizer use multiprecision.') - - args = parser.parse_args() - return args - - if __name__ == "__main__": args = get_args() print(args) @@ -379,6 +330,11 @@ def get_args(): net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) net.cast(args.precision) + #net._ctx_embedding.hybridize() + #net._q_embedding.hybridize() + net._attention_layer.hybridize() + net._modeling_layer.hybridize() + net._output_layer.hybridize() run_training(net, train_dataloader, evaluator, ctx, options=args) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index eb825ff377..0ad53e5224 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -16,6 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import argparse import os import inspect @@ -73,6 +74,55 @@ def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=lo return folder +def get_args(): + """Get console arguments + """ + parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') + parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') + parser.add_argument('--train', type=bool, default=False, help='Run training') + parser.add_argument('--evaluate', type=bool, default=False, help='Run evaluation on dev dataset') + parser.add_argument('--preprocessed_dataset_path', type=str, + default="preprocessed_dataset.p", help='Path to preprocessed dataset') + parser.add_argument('--preprocessed_val_dataset_path', type=str, + default="preprocessed_val_dataset.p", help='Path to preprocessed ' + 'validation dataset') + parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') + parser.add_argument('--embedding_size', type=int, default=100, + help='Dimension of the word embedding') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, + help='Number of layers in Contextual embedding layer of BiDAF') + parser.add_argument('--highway_num_layers', type=int, default=2, + help='Number of layers in Highway layer of BiDAF') + parser.add_argument('--modeling_num_layers', type=int, default=2, + help='Number of layers in Modeling layer of BiDAF') + parser.add_argument('--output_num_layers', type=int, default=1, + help='Number of layers in Output layer of BiDAF') + parser.add_argument('--batch_size', type=int, default=128, help='Batch size') + parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') + parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') + parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') + parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') + parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') + parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') + parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') + parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') + parser.add_argument('--gpu', type=str, default=None, + help='Coma-separated ids of the gpu to use. Empty means to use cpu.') + parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], + help='Use float16 or float32 precision') + parser.add_argument('--save_prediction_path', type=str, default='', + help='Path to save predictions') + parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, + help='When using float16, shall optimizer use multiprecision.') + + args = parser.parse_args() + return args + + def get_combined_dim(combination, tensor_dims): """ For use with :func:`combine_tensors`. This function computes the resultant dimension when @@ -153,7 +203,7 @@ def combine_tensors(combination, tensors): return nd.concat(to_concatenate, dim=-1) -def masked_softmax(vector, mask): +def masked_softmax(F, vector, mask): """ ``nd.softmax(vector)`` does not work if some elements of ``vector`` should be masked. This performs a softmax on just the non-masked portions of ``vector``. Passing @@ -166,12 +216,12 @@ def masked_softmax(vector, mask): that uses categorical cross-entropy loss. """ if mask is None: - result = nd.softmax(vector, axis=-1) + result = F.softmax(vector, axis=-1) else: # To limit numerical errors from large vector elements outside the mask, we zero these out. - result = nd.softmax(vector * mask, axis=-1) + result = F.softmax(vector * mask, axis=-1) result = result * mask - result = result / (result.sum(axis=1, keepdims=True) + 1e-13) + result = F.broadcast_div(result, (result.sum(axis=1, keepdims=True) + 1e-13)) return result @@ -203,9 +253,12 @@ def masked_log_softmax(vector, mask): return nd.log_softmax(vector, axis=1) -def _last_dimension_applicator(function_to_apply, +def _last_dimension_applicator(F, + function_to_apply, tensor, - mask): + mask, + tensor_shape, + mask_shape): """ Takes a tensor with 3 or more dimensions and applies a function over the last dimension. We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) @@ -213,36 +266,38 @@ def _last_dimension_applicator(function_to_apply, has the same shape as the tensor, then flatten them both to be 2D, pass them through the function and put the tensor back in its original shape. """ - tensor_shape = tensor.shape - reshaped_tensor = tensor.reshape(-1, tensor.shape[-1]) + reshaped_tensor = tensor.reshape(shape=(-1, tensor_shape[-1])) + if mask is not None: - while len(mask.shape) < len(tensor.shape): + shape_difference = len(tensor_shape) - len(mask_shape) + for i in range(0, shape_difference): mask = mask.expand_dims(1) - mask = mask.broadcast_to(shape=tensor.shape) - mask = mask.reshape(-1, mask.shape[-1]) - reshaped_result = function_to_apply(reshaped_tensor, mask) - return reshaped_result.reshape(*tensor_shape) + mask = mask.broadcast_to(shape=tensor_shape) + mask = mask.reshape(shape=(-1, mask_shape[-1])) + reshaped_result = function_to_apply(F, reshaped_tensor, mask) + return reshaped_result.reshape(shape=tensor_shape) + return reshaped_result -def last_dim_softmax(tensor, mask): +def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape): """ Takes a tensor with 3 or more dimensions and does a masked softmax over the last dimension. We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) has shape ``(batch_size, sequence_length)``. """ - return _last_dimension_applicator(masked_softmax, tensor, mask) + return _last_dimension_applicator(F, masked_softmax, tensor, mask, tensor_shape, mask_shape) -def last_dim_log_softmax(tensor, mask): +def last_dim_log_softmax(F, tensor, mask, tensor_shape, mask_shape): """ Takes a tensor with 3 or more dimensions and does a masked log softmax over the last dimension. We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) has shape ``(batch_size, sequence_length)``. """ - return _last_dimension_applicator(masked_log_softmax, tensor, mask) + return _last_dimension_applicator(F, masked_log_softmax, tensor, mask, tensor_shape, mask_shape) -def weighted_sum(F, matrix, attention): +def weighted_sum(F, matrix, attention, matrix_shape, attention_shape): """ Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an "attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical @@ -266,21 +321,21 @@ def weighted_sum(F, matrix, attention): ``(batch_size, num_documents, num_queries, embedding_dim)`` respectively. """ - if len(attention.shape) == 2 and len(matrix.shape) == 3: + if len(attention_shape) == 2 and len(matrix_shape) == 3: return F.squeeze(F.batch_dot(attention.expand_dims(1), matrix), axis=1) - if len(attention.shape) == 3 and len(matrix.shape) == 3: + if len(attention_shape) == 3 and len(matrix_shape) == 3: return F.batch_dot(attention, matrix) - if len(matrix.shape) - 1 < len(attention.shape): - expanded_size = list(matrix.shape) - for i in range(len(attention.shape) - len(matrix.shape) + 1): + if len(matrix_shape) - 1 < len(attention_shape): + expanded_size = list(matrix_shape) + for i in range(len(attention_shape) - len(matrix_shape) + 1): matrix = matrix.expand_dims(1) expanded_size.insert(i + 1, attention.shape[i + 1]) matrix = matrix.broadcast_to(*expanded_size) - intermediate = attention.expand_dims(-1).broadcast_to(matrix.shape) * matrix + intermediate = attention.expand_dims(-1).broadcast_to(matrix_shape) * matrix return intermediate.sum(axis=-2) -def replace_masked_values(tensor, mask, replace_with): +def replace_masked_values(F, tensor, mask, replace_with): """ Replaces all masked values in ``tensor`` with ``replace_with``. ``mask`` must be broadcastable to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we @@ -290,4 +345,4 @@ def replace_masked_values(tensor, mask, replace_with): # the `replace_with` value. one_minus_mask = 1.0 - mask values_to_add = replace_with * one_minus_mask - return tensor * mask + values_to_add + return F.broadcast_add(F.broadcast_mul(tensor, mask), values_to_add) diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 567c44e691..5d16534203 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -19,13 +19,20 @@ import os import pytest -from mxnet import init, nd +from mxnet import init, nd, autograd +from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss +from mxnet.gluon.rnn import LSTM +from types import SimpleNamespace import gluonnlp as nlp from gluonnlp.data import SQuAD +from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.data_processing import SQuADTransform, VocabProvider from scripts.question_answering.question_answering import * +from scripts.question_answering.similarity_function import DotProductSimilarity +from scripts.question_answering.train_question_answering import get_record_per_answer_span question_max_length = 30 context_max_length = 400 @@ -91,24 +98,68 @@ def test_bidaf_embedding(): if i < batch_size]) # need to remove question id before feeding the data to data loader - loadable_data = SimpleDataset([(r[0], r[2], r[3], r[4], r[5], r[6]) for r in processed_dataset]) - dataloader = DataLoader(loadable_data, batch_size=5) + loadable_data, dataloader = get_record_per_answer_span(processed_dataset, get_args(batch_size)) word_vocab = vocab_provider.get_word_level_vocab() word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) char_vocab = vocab_provider.get_char_level_vocab() - embedding = BiDAFEmbedding(word_vocab=word_vocab, char_vocab=char_vocab) + embedding = BiDAFEmbedding(word_vocab=word_vocab, + char_vocab=char_vocab, + batch_size=batch_size, + max_seq_len=question_max_length, + precision="float16") + embedding.cast("float16") embedding.initialize(init.Xavier(magnitude=2.24)) + embedding.hybridize(static_alloc=True) + state = embedding.begin_state() + + trainer = Trainer(embedding.collect_params(), "sgd", {"learning_rate": 0.1, + "multi_precision": True}) + + for i, (data, label) in enumerate(dataloader): + with autograd.record(): + record_index, q_words, ctx_words, q_chars, ctx_chars = data + q_words = q_words.astype("float16") + ctx_words = ctx_words.astype("float16") + q_chars = q_chars.astype("float16") + ctx_chars = ctx_chars.astype("float16") + label = label.astype("float16") + # passing only question_words_nd and question_chars_nd batch + out = embedding(q_words, q_chars, state) + assert out is not None + + out.backward() + trainer.step(batch_size) + break - contextual_embedding_h0 = nd.random.uniform(shape=(4, batch_size, 100)) - contextual_embedding_c0 = nd.random.uniform(shape=(4, batch_size, 100)) - for i, data in enumerate(dataloader): - # passing only question_words_nd and question_chars_nd batch - out = embedding([data[1], data[3]], [contextual_embedding_h0, contextual_embedding_c0]) - assert out is not None - break +def test_attention_layer(): + batch_size = 5 + + ctx_fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 2 * embedding_size), + dtype="float16") + + q_fake_data = nd.random.uniform(shape=(batch_size, question_max_length, 2 * embedding_size), + dtype="float16") + + ctx_fake_mask = nd.ones(shape=(batch_size, context_max_length), dtype="float16") + q_fake_mask = nd.ones(shape=(batch_size, question_max_length), dtype="float16") + + layer = BidirectionalAttentionFlow(DotProductSimilarity(), + batch_size, + context_max_length, + question_max_length, + 2 * embedding_size) + + layer.cast("float16") + layer.initialize() + layer.hybridize(static_alloc=True) + + with autograd.record(): + output = layer(ctx_fake_data, q_fake_data, q_fake_mask, ctx_fake_mask) + + assert output.shape == (batch_size, context_max_length, 8 * embedding_size) def test_modeling_layer(): @@ -117,15 +168,24 @@ def test_modeling_layer(): # The modeling layer receive input in a shape of batch_size x T x 8d # T is the sequence length of context which is context_max_length # d is the size of embedding, which is embedding_size - fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size)) + fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size), + dtype="float16") # We assume that attention is already return data in TNC format attention_output = nd.transpose(fake_data, axes=(1, 0, 2)) - layer = BiDAFModelingLayer() - # The model doesn't need to know the hidden states, so I don't hold variables for the states + layer = BiDAFModelingLayer(batch_size, precision="float16") + layer.cast("float16") layer.initialize() + layer.hybridize(static_alloc=True) + state = layer.begin_state() + + trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": "0.1", + "multi_precision": True}) + + with autograd.record(): + output = layer(attention_output, state) - output = layer(attention_output) + output.backward() # According to the paper, the output should be 2d x T assert output.shape == (context_max_length, batch_size, 2 * embedding_size) @@ -138,14 +198,98 @@ def test_output_layer(): # (batch_size, context_max_length, 8 * embedding_size) # The modeling layer returns data in TNC format - modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size)) + modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size), + dtype="float16") # The layer assumes that attention is already return data in TNC format - attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size)) + attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size), + dtype="float16") - layer = BiDAFOutputLayer() + layer = BiDAFOutputLayer(batch_size, precision="float16") + layer.cast("float16") # The model doesn't need to know the hidden states, so I don't hold variables for the states layer.initialize() + layer.hybridize(static_alloc=True) + state = layer.begin_state() - output = layer(attention_output, modeling_output) + trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": 0.1, + "multi_precision": True}) + + with autograd.record(): + output = layer(attention_output, modeling_output, state) + + output.backward() # We expect final numbers as batch_size x 2 (first start index, second end index) - assert output.shape == (batch_size, 2) + assert output.shape == (batch_size, 2, 400) + + +def test_bidaf_model(): + options = get_args(batch_size=5) + + dataset = SQuAD(segment='dev', root='tests/data/squad') + vocab_provider = VocabProvider(dataset) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) + + # for performance reason, process only batch_size # of records + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < options.batch_size]) + + # need to remove question id before feeding the data to data loader + loadable_data, dataloader = get_record_per_answer_span(processed_dataset, options) + + word_vocab = vocab_provider.get_word_level_vocab() + word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) + char_vocab = vocab_provider.get_char_level_vocab() + + model = BiDAFModel(word_vocab=word_vocab, + char_vocab=char_vocab, + options=options) + + model.cast("float16") + model.initialize(init.Xavier(magnitude=2.24)) + model.hybridize(static_alloc=True) + + ctx_embedding_begin_state = model.ctx_embedding.begin_state() + q_embedding_begin_state = model.q_embedding.begin_state() + m_layer_begin_state = model.modeling_layer.begin_state() + o_layer_begin_state = model.output_layer.begin_state() + + loss_function = SoftmaxCrossEntropyLoss() + trainer = Trainer(model.collect_params(), "adadelta", {"learning_rate": 0.5, + "multi_precision": True}) + + for i, (data, label) in enumerate(dataloader): + record_index, q_words, ctx_words, q_chars, ctx_chars = data + q_words = q_words.astype("float16") + ctx_words = ctx_words.astype("float16") + q_chars = q_chars.astype("float16") + ctx_chars = ctx_chars.astype("float16") + label = label.astype("float16") + + with autograd.record(): + out = model(record_index, q_words, ctx_words, q_chars, ctx_chars, + ctx_embedding_begin_state, q_embedding_begin_state, + m_layer_begin_state, o_layer_begin_state) + loss = loss_function(out, label) + + loss.backward() + trainer.step(options.batch_size) + break + + +def get_args(batch_size): + options = SimpleNamespace() + options.ctx_embedding_num_layers = 2 + options.embedding_size = 100 + options.dropout = 0.2 + options.ctx_embedding_num_layers = 2 + options.highway_num_layers = 2 + options.modeling_num_layers = 2 + options.output_num_layers = 2 + options.batch_size = batch_size + options.ctx_max_len = context_max_length + options.q_max_len = question_max_length + options.word_max_len = max_chars_per_word + options.precision = "float16" + + return options From 5e4975be3ff2481fa873b1fc6aa62ff339941e46 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 28 Sep 2018 10:00:06 -0700 Subject: [PATCH 14/43] Float16 + hybridize works. TODO:replace hard codes --- .../question_answering/question_answering.py | 103 +++++++++--------- .../train_question_answering.py | 54 ++++++--- scripts/tests/test_question_answering.py | 4 +- 3 files changed, 91 insertions(+), 70 deletions(-) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index dd61472dfa..a42a1b33cd 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -47,32 +47,36 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, self._max_seq_len = max_seq_len self._precision = precision self._embedding_size = embedding_size - self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), output_dim=8) - self._char_conv_embedding = ConvolutionalEncoder( - embed_size=8, - num_filters=(100,), - ngram_filter_sizes=(5,), - num_highway=None, - conv_layer_activation='relu', - output_size=None - ) - - self._word_embedding = nn.Embedding(input_dim=len(word_vocab), output_dim=embedding_size) - - self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) - self._contextual_embedding = LSTM(hidden_size=embedding_size, - num_layers=contextual_embedding_nlayers, - bidirectional=True, input_size=2 * embedding_size) + + with self.name_scope(): + self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), + output_dim=8) + self._char_conv_embedding = ConvolutionalEncoder( + embed_size=8, + num_filters=(100,), + ngram_filter_sizes=(5,), + num_highway=None, + conv_layer_activation='relu', + output_size=None + ) + + self._word_embedding = nn.Embedding(input_dim=len(word_vocab), + output_dim=embedding_size) + + self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) + self._contextual_embedding = LSTM(hidden_size=embedding_size, + num_layers=contextual_embedding_nlayers, + bidirectional=True, input_size=2 * embedding_size) def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): super(BiDAFEmbedding, self).initialize(init, ctx, verbose, force_reinit) self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) - def begin_state(self): - state = self._contextual_embedding.begin_state(self._batch_size, - dtype=self._precision) - - return state + def begin_state(self, ctx): + state_list = [self._contextual_embedding.begin_state(self._batch_size, + dtype=self._precision, + ctx=c) for c in ctx] + return state_list def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively @@ -95,15 +99,15 @@ def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): def convolute(token_of_all_batches, _): return self._char_conv_embedding(token_of_all_batches), [] - token_list, _ = F.contrib.foreach(convolute, char_level_data, []) + char_embedded, _ = F.contrib.foreach(convolute, char_level_data, []) # Step 4. Concat all tokens embeddings to create a single tensor. - char_embedded = F.concat(*token_list, dim=0) + # char_embedded = F.concat(*token_list, dim=0) # Step 5. Reshape tensor to match dimensions of embedded words - char_embedded = char_embedded.reshape(shape=(self._max_seq_len, - self._batch_size, - self._embedding_size)) + # char_embedded = char_embedded.reshape(shape=(self._max_seq_len, + # self._batch_size, + # self._embedding_size)) # Concat embeddings, making channels size = 200 highway_input = F.concat(char_embedded, word_embedded, dim=2) @@ -111,9 +115,6 @@ def convolute(token_of_all_batches, _): # Pass through highway, shape remains unchanged highway_output = self._highway_network(highway_input) - if contextual_embedding_state is None: - contextual_embedding_state = self.begin_state() - ce_output, ce_state = self._contextual_embedding(highway_output, contextual_embedding_state) return ce_output @@ -149,18 +150,18 @@ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, self._batch_size = batch_size self._precision = precision - self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, - bidirectional=biflag, input_size=800) - def begin_state(self): - state = self._modeling_layer.begin_state(self._batch_size, - dtype=self._precision) - return state + with self.name_scope(): + self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, + bidirectional=biflag, input_size=800) - def hybrid_forward(self, F, x, state, *args): - if state is None: - state = self.begin_state() + def begin_state(self, ctx): + state_list = [self._modeling_layer.begin_state(self._batch_size, + dtype=self._precision, + ctx=c) for c in ctx] + return state_list + def hybrid_forward(self, F, x, state, *args): out, _ = self._modeling_layer(x, state) return out @@ -203,16 +204,19 @@ def __init__(self, batch_size, span_start_input_dim=100, units=None, nlayers=1, self._batch_size = batch_size self._precision = precision - self._start_index_dense = nn.Dense(units=units) - self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, - num_layers=nlayers, dropout=dropout, bidirectional=biflag, - input_size=200) - self._end_index_dense = nn.Dense(units=units) - def begin_state(self): - state = self._end_index_lstm.begin_state(self._batch_size, - dtype=self._precision) - return state + with self.name_scope(): + self._start_index_dense = nn.Dense(units=units, in_units=400000) + self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, + num_layers=nlayers, dropout=dropout, bidirectional=biflag, + input_size=200) + self._end_index_dense = nn.Dense(units=units, in_units=400000) + + def begin_state(self, ctx): + state_list = [self._end_index_lstm.begin_state(self._batch_size, + dtype=self._precision, + ctx=c) for c in ctx] + return state_list def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-differ @@ -220,9 +224,6 @@ def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-di start_index_input = F.transpose(F.concat(x, m, dim=2), axes=(1, 0, 2)) start_index_dense_output = self._start_index_dense(start_index_input) - if state is None: - state = self.begin_state() - end_index_input_part, _ = self._end_index_lstm(m, state) end_index_input = F.transpose(F.concat(x, end_index_input_part, dim=2), axes=(1, 0, 2)) @@ -289,7 +290,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): dropout=options.dropout, precision=options.precision) - def hybrid_forward(self, F, ri, qw, cw, qc, cc, + def hybrid_forward(self, F, qw, cw, qc, cc, ctx_embedding_states=None, q_embedding_states=None, modeling_layer_states=None, diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index ce99b05f5e..65602335dc 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -102,8 +102,12 @@ def get_record_per_answer_span(processed_dataset, options): global_index += 1 loadable_data = ArrayDataset(data_no_label, labels) - dataloader = DataLoader(loadable_data, batch_size=options.batch_size, shuffle=True, - last_batch='keep') + dataloader = DataLoader(loadable_data, + batch_size=options.batch_size * len(get_context(options)), + shuffle=True, + last_batch='discard', + num_workers=(multiprocessing.cpu_count() - + len(get_context(options)) - 2)) return loadable_data, dataloader @@ -145,6 +149,7 @@ def get_context(options): if options.gpu is None: ctx.append(mx.cpu(0)) + ctx.append(mx.cpu(1)) print('Use CPU') else: indices = options.gpu.split(',') @@ -155,7 +160,7 @@ def get_context(options): return ctx -def run_training(net, dataloader, evaluator, ctx, options): +def run_training(net, dataloader, ctx, options): """Main function to do training of the network Parameters @@ -164,19 +169,25 @@ def run_training(net, dataloader, evaluator, ctx, options): Network to train dataloader : `DataLoader` Initialized dataloader - evaluator: `PerformanceEvaluator` - Used to plug in official evaluation script ctx: `Context` Training context options : `Namespace` Training arguments """ - trainer = Trainer(net.collect_params(), args.optimizer, - {'learning_rate': options.lr}, - kvstore="device") + hyperparameters = {'learning_rate': options.lr} + + if options.precision == 'float16' and options.use_multiprecision_in_optimizer: + hyperparameters["multi_precision"] = True + + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device") loss_function = SoftmaxCrossEntropyLoss() + ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) + o_layer_begin_state_list = net.output_layer.begin_state(ctx) + train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) print("Starting training...") @@ -209,10 +220,20 @@ def run_training(net, dataloader, evaluator, ctx, options): mx.nd.waitall() losses = [] - for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, - q_chars, ctx_chars, label): + for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ + q_embedding_begin_state, m_layer_begin_state, \ + o_layer_begin_state in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label, + ctx_embedding_begin_state_list, + q_embedding_begin_state_list, + m_layer_begin_state_list, + o_layer_begin_state_list): with autograd.record(): - o, _, _ = net(ri, qw, cw, qc, cc) + o = net(qw, cw, qc, cc, + ctx_embedding_begin_state, + q_embedding_begin_state, + m_layer_begin_state, + o_layer_begin_state) loss = loss_function(o, l) losses.append(loss) @@ -293,6 +314,7 @@ def load_transformed_dataset(path): if __name__ == "__main__": args = get_args() + args.batch_size = int(args.batch_size / len(get_context(args))) print(args) logging_config(args.save_dir) @@ -328,15 +350,11 @@ def load_transformed_dataset(path): evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") - net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) net.cast(args.precision) - #net._ctx_embedding.hybridize() - #net._q_embedding.hybridize() - net._attention_layer.hybridize() - net._modeling_layer.hybridize() - net._output_layer.hybridize() + net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) + net.hybridize(static_alloc=True) - run_training(net, train_dataloader, evaluator, ctx, options=args) + run_training(net, train_dataloader, ctx, options=args) if args.evaluate: print("Running in evaluation mode") diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 5d16534203..e05d0527d1 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -256,7 +256,7 @@ def test_bidaf_model(): loss_function = SoftmaxCrossEntropyLoss() trainer = Trainer(model.collect_params(), "adadelta", {"learning_rate": 0.5, - "multi_precision": True}) + "multi_precision": True}) for i, (data, label) in enumerate(dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data @@ -276,6 +276,8 @@ def test_bidaf_model(): trainer.step(options.batch_size) break + nd.waitall() + def get_args(batch_size): options = SimpleNamespace() From 69ea71c1d14b975c38ff226019b807faf85c8d38 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 28 Sep 2018 10:16:48 -0700 Subject: [PATCH 15/43] Hard code removed --- scripts/question_answering/question_answering.py | 7 ++++--- scripts/question_answering/train_question_answering.py | 8 ++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index a42a1b33cd..e0ca9c029a 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -201,16 +201,17 @@ def __init__(self, batch_size, span_start_input_dim=100, units=None, nlayers=1, super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) units = 4 * span_start_input_dim if units is None else units + embedding_size = 1000 self._batch_size = batch_size self._precision = precision with self.name_scope(): - self._start_index_dense = nn.Dense(units=units, in_units=400000) + self._start_index_dense = nn.Dense(units=units, in_units=units * embedding_size) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, - input_size=200) - self._end_index_dense = nn.Dense(units=units, in_units=400000) + input_size=2 * span_start_input_dim) + self._end_index_dense = nn.Dense(units=units, in_units=units * embedding_size) def begin_state(self, ctx): state_list = [self._end_index_lstm.begin_state(self._batch_size, diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 65602335dc..750c14d185 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -246,10 +246,6 @@ def run_training(net, dataloader, ctx, options): avg_loss += l.mean().as_in_context(avg_loss.context) mx.nd.waitall() - print("Start evaluate performance") - #eval_results = evaluator.evaluate_performance(net, ctx, options) - eval_results = {} - print("End evaluate performance") avg_loss /= (i * len(ctx)) @@ -258,9 +254,9 @@ def run_training(net, dataloader, ctx, options): epoch_time = time() - e_start print("\tEPOCH {:2}: train loss {:4.2f} | batch {:4} | lr {:5.3f} | " - "Time per epoch {:5.2f} seconds | {}" + "Time per epoch {:5.2f} seconds" .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, - epoch_time, eval_results)) + epoch_time)) save_model_parameters(net, e, options) From 74707fc25f81e44d56e6fe7f0933355e571b7476 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 28 Sep 2018 16:20:06 -0700 Subject: [PATCH 16/43] Some useless code removed --- scripts/question_answering/similarity_function.py | 1 + .../train_question_answering.py | 15 +++++++-------- scripts/question_answering/utils.py | 1 - 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index 26616a6729..ec4900869b 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -66,6 +66,7 @@ def hybrid_forward(self, F, array_1, array_2): result = (array_1 * array_2).sum(axis=-1) if self._scale_output: + # result *= F.sqrt(array_1.shape[-1]) result *= F.contrib.div_sqrt_dim(array_1) return result diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 750c14d185..b0000054ed 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -104,7 +104,7 @@ def get_record_per_answer_span(processed_dataset, options): loadable_data = ArrayDataset(data_no_label, labels) dataloader = DataLoader(loadable_data, batch_size=options.batch_size * len(get_context(options)), - shuffle=True, + # shuffle=True, last_batch='discard', num_workers=(multiprocessing.cpu_count() - len(get_context(options)) - 2)) @@ -148,8 +148,7 @@ def get_context(options): ctx = [] if options.gpu is None: - ctx.append(mx.cpu(0)) - ctx.append(mx.cpu(1)) + ctx.append(mx.cpu()) print('Use CPU') else: indices = options.gpu.split(',') @@ -183,11 +182,6 @@ def run_training(net, dataloader, ctx, options): trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device") loss_function = SoftmaxCrossEntropyLoss() - ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) - m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) - o_layer_begin_state_list = net.output_layer.begin_state(ctx) - train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) print("Starting training...") @@ -195,6 +189,11 @@ def run_training(net, dataloader, ctx, options): for e in range(args.epochs): avg_loss *= 0 # Zero average loss of each epoch + ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) + o_layer_begin_state_list = net.output_layer.begin_state(ctx) + for i, (data, label) in enumerate(dataloader): # start timing for the first batch of epoch if i == 0: diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 0ad53e5224..515b6f0bc7 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -276,7 +276,6 @@ def _last_dimension_applicator(F, mask = mask.reshape(shape=(-1, mask_shape[-1])) reshaped_result = function_to_apply(F, reshaped_tensor, mask) return reshaped_result.reshape(shape=tensor_shape) - return reshaped_result def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape): From d35c802c17f9883e216d3564c393ce9124033bf5 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 5 Oct 2018 17:50:59 -0700 Subject: [PATCH 17/43] Bug fix in data preprocessing --- gluonnlp/data/question_answering.py | 15 ++- scripts/question_answering/data_processing.py | 87 ++++++++++--- .../performance_evaluator.py | 48 +++++-- .../question_answering/question_answering.py | 33 +++-- scripts/question_answering/tokenizer.py | 60 +++++++++ .../train_question_answering.py | 105 ++++++++++++--- scripts/question_answering/utils.py | 10 +- scripts/tests/test_question_answering.py | 121 +++++++++++++++++- 8 files changed, 422 insertions(+), 57 deletions(-) create mode 100644 scripts/question_answering/tokenizer.py diff --git a/gluonnlp/data/question_answering.py b/gluonnlp/data/question_answering.py index ffc29766f0..dafc3fc22c 100644 --- a/gluonnlp/data/question_answering.py +++ b/gluonnlp/data/question_answering.py @@ -80,7 +80,7 @@ def __init__(self, segment='train', root=os.path.join('~', '.mxnet', 'datasets', self._segment = segment self._get_data() - super(SQuAD, self).__init__(self._read_data()) + super(SQuAD, self).__init__(SQuAD._get_records(self._read_data())) def _get_data(self): """Load data from the file. Does nothing if data was loaded before @@ -108,6 +108,15 @@ def _read_data(self): Question id and list_of_answers also substituted with indices, so it could be later converted into nd.array + Returns + ------- + List[Tuple] + Flatten list of questions + """ + """Read data.json from disk and flats it to the following format: + Entry = (record_index, question_id, question, context, answer_list, answer_start_indices). + Question id and list_of_answers also substituted with indices, so it could be later + converted into nd.array Returns ------- List[Tuple] @@ -116,9 +125,9 @@ def _read_data(self): _, data_file_name, _ = self._data_file[self._segment] with open(os.path.join(self._root, data_file_name)) as f: - samples = json.load(f) + json_data = json.load(f) - return SQuAD._get_records(samples) + return json_data @staticmethod def _get_records(json_dict): diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 70d5cbf76f..52c600d993 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,6 +19,9 @@ # pylint: disable= """SQuAD data preprocessing.""" +from gluonnlp.data import SpacyTokenizer +from scripts.question_answering.tokenizer import BiDAFTokenizer + __all__ = ['SQuADTransform', 'VocabProvider', 'preprocess_dataset'] import re @@ -58,6 +61,7 @@ class SQuADTransform(object): def __init__(self, vocab_provider, question_max_length, context_max_length, max_chars_per_word): self._word_vocab = vocab_provider.get_word_level_vocab() self._char_vocab = vocab_provider.get_char_level_vocab() + self._tokenizer = vocab_provider.get_tokenizer() self._question_max_length = question_max_length self._context_max_length = context_max_length @@ -71,14 +75,17 @@ def __call__(self, record_index, question_id, question, context, answer_list, Method converts text into numeric arrays based on Vocabulary. Answers are not processed, as they are not needed in input """ - question_words = self._word_vocab[question.split()[:self._question_max_length]] - context_words = self._word_vocab[context.split()[:self._context_max_length]] + question_tokens = self._tokenizer(question) + context_tokens = self._tokenizer(context) + + question_words = self._word_vocab[question_tokens[:self._question_max_length]] + context_words = self._word_vocab[context_tokens[:self._context_max_length]] question_chars = [self._char_vocab[list(iter(word))] - for word in question.split()[:self._question_max_length]] + for word in question_tokens[:self._question_max_length]] context_chars = [self._char_vocab[list(iter(word))] - for word in context.split()[:self._context_max_length]] + for word in context_tokens[:self._context_max_length]] question_words_nd = self._pad_to_max_word_length(question_words, self._question_max_length) question_chars_nd = self._padder(question_chars) @@ -89,25 +96,73 @@ def __call__(self, record_index, question_id, question, context, answer_list, context_chars_nd = self._padder(context_chars) context_chars_nd = self._pad_to_max_char_length(context_chars_nd, self._context_max_length) - answer_spans = SQuADTransform._get_answer_spans(answer_list, answer_start_list) + answer_spans = SQuADTransform._get_answer_spans(context, context_tokens, answer_list, + answer_start_list) return (record_index, question_id, question_words_nd, context_words_nd, question_chars_nd, context_chars_nd, answer_spans) @staticmethod - def _get_answer_spans(answer_list, answer_start_list): - """Find all answer spans from the context, returning start_index and end_index + def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): + """Find all answer spans from the context, returning start_index and end_index. + Each index is a index of a token + :param list[str] context_tokens: Tokenized paragraph :param list[str] answer_list: List of all answers - :param list[int] answer_start_list: List of all answers' start indices Returns ------- List[Tuple] list of Tuple(answer_start_index answer_end_index) per question """ - return [(answer_start_list[i], answer_start_list[i] + len(answer)) - for i, answer in enumerate(answer_list)] + answer_spans = [] + # SQuAD answers doesn't always match to used tokens in the context. Sometimes there is only + # a partial match. We use the same method as used in original implementation: + # 1. Find char index range for all tokens of context + # 2. Foreach answer + # 2.1 Find char index range for the answer (not tokenized) + # 2.2 Find Context token indices which char indices contains answer char indices + # 2.3. Return first and last token indices + context_char_indices = SQuADTransform._get_char_indices(context, context_tokens) + + for answer_start_char_index, answer in zip(answer_start_list, answer_list): + answer_token_indices = [] + answer_end_char_index = answer_start_char_index + len(answer) + + for context_token_index, context_char_span in enumerate(context_char_indices): + if not (answer_end_char_index <= context_char_span[0] or + answer_start_char_index >= context_char_span[1]): + answer_token_indices.append(context_token_index) + + if len(answer_token_indices) == 0: + print("Warning: Answer {} not found for context {}".format(answer, context)) + else: + answer_span = (answer_token_indices[0], + answer_token_indices[len(answer_token_indices) - 1]) + answer_spans.append(answer_span) + + if len(answer_spans) == 0: + print("Warning: No answers found for context {}".format(context_tokens)) + + return answer_spans + + @staticmethod + def _get_char_indices(text, text_tokens): + """Match token with character indices + + :param str text: Text + :param List[str] text_tokens: Tokens of the text + :return: List of char_indexes where the order equals to token index + """ + char_indices_per_token = [] + current_index = 0 + + for token in text_tokens: + current_index = text.find(token, current_index) + char_indices_per_token.append((current_index, current_index + len(token))) + current_index += len(token) + + return char_indices_per_token def _pad_to_max_char_length(self, item, max_item_length): """Pads all tokens to maximum size @@ -154,8 +209,13 @@ def _pad_to_max_word_length(item, max_length): class VocabProvider(object): """Provides word level and character level vocabularies """ - def __init__(self, dataset): + def __init__(self, dataset, tokenizer=BiDAFTokenizer()): self._dataset = dataset + self._tokenizer = tokenizer + + def get_tokenizer(self): + """Provides tokenizer used to create vocab""" + return self._tokenizer def get_char_level_vocab(self): """Provides character level vocabulary @@ -176,10 +236,7 @@ def get_word_level_vocab(self): Word level vocabulary """ - def simple_tokenize(source_str, token_delim=' ', seq_delim='\n'): - return list(filter(None, re.split(token_delim + '|' + seq_delim, source_str))) - - return VocabProvider._create_squad_vocab(simple_tokenize, self._dataset) + return VocabProvider._create_squad_vocab(self._tokenizer, self._dataset) @staticmethod def _create_squad_vocab(tokenization_fn, dataset): diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 944d0a810b..5e62a2ece2 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -18,13 +18,16 @@ # under the License. """Performance evaluator - a proxy class used for plugging in official validation script""" +import multiprocessing +import re from mxnet import nd, gluon from mxnet.gluon.data import DataLoader, ArrayDataset from scripts.question_answering.metric import evaluate class PerformanceEvaluator: - def __init__(self, evaluation_dataset, json_data, question_id_mapper): + def __init__(self, tokenizer, evaluation_dataset, json_data, question_id_mapper): + self._tokenizer = tokenizer self._evaluation_dataset = evaluation_dataset self._json_data = json_data self._mapper = question_id_mapper @@ -50,7 +53,9 @@ def evaluate_performance(self, net, ctx, options): pred = {} eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) for r in self._evaluation_dataset]) - eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, last_batch='keep') + eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, + last_batch='keep', + num_workers=(multiprocessing.cpu_count() - len(ctx) - 2)) for i, data in enumerate(eval_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data @@ -67,10 +72,26 @@ def evaluate_performance(self, net, ctx, options): q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) + o_layer_begin_state_list = net.output_layer.begin_state(ctx) + outs = [] - for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, q_chars, ctx_chars): - out, _, _ = net((ri, qw, cw, qc, cc)) + for ri, qw, cw, qc, cc, ctx_embedding_begin_state, \ + q_embedding_begin_state, m_layer_begin_state, \ + o_layer_begin_state in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, + ctx_embedding_begin_state_list, + q_embedding_begin_state_list, + m_layer_begin_state_list, + o_layer_begin_state_list): + out = net(qw, cw, qc, cc, + ctx_embedding_begin_state, + q_embedding_begin_state, + m_layer_begin_state, + o_layer_begin_state) outs.append(out) for out in outs: @@ -83,14 +104,16 @@ def evaluate_performance(self, net, ctx, options): idx = int(idx.asscalar()) start = int(start.asscalar()) end = int(end.asscalar()) - pred[self._mapper.idx_to_question_id[idx]] = self.get_text_result(idx, - (start, end)) + question_id = self._mapper.idx_to_question_id[idx] + pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) + if options.save_prediction_path: with open(options.save_prediction_path, "w") as f: for item in pred.items(): - f.write("QId {}, Answer: {}\n".format(item[0], item[1])) + f.write("{}: {}-{} Answer: {}\n".format(item[0], item[1][0], + item[1][1], item[1][2])) - return evaluate(self._json_data['data'], pred) + return evaluate(self._json_data['data'], {k: v[2] for k, v in pred.items()}) def get_text_result(self, idx, answer_span): """Converts answer span into actual text from paragraph @@ -115,16 +138,17 @@ def get_text_result(self, idx, answer_span): question_id = self._mapper.idx_to_question_id[idx] context = self._mapper.question_id_to_context[question_id] + context_tokens = self._tokenizer(context) # start index is above the context length - return cannot provide an answer - if start > len(context) - 1: + if start > len(context_tokens) - 1: return '' # end index is above the context length - let's take answer to the end of the context - if end > len(context) - 1: - end = len(context) - 1 + if end > len(context_tokens) - 1: + end = len(context_tokens) - 1 - text = ' '.join(context.split()[start:end + 1]) + text = ' '.join(context_tokens[start:end + 1]) return text @staticmethod diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index e0ca9c029a..ca59de0b15 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -60,7 +60,8 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, output_size=None ) - self._word_embedding = nn.Embedding(input_dim=len(word_vocab), + self._word_embedding = nn.Embedding(prefix="predefined_embedding_layer", + input_dim=len(word_vocab), output_dim=embedding_size) self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) @@ -72,10 +73,14 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_ super(BiDAFEmbedding, self).initialize(init, ctx, verbose, force_reinit) self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) - def begin_state(self, ctx): - state_list = [self._contextual_embedding.begin_state(self._batch_size, + def begin_state(self, ctx, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [self._batch_size] * len(ctx) + + state_list = [self._contextual_embedding.begin_state(b, dtype=self._precision, - ctx=c) for c in ctx] + ctx=c) for c, b in zip(ctx, + batch_sizes)] return state_list def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): @@ -155,10 +160,14 @@ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=800) - def begin_state(self, ctx): - state_list = [self._modeling_layer.begin_state(self._batch_size, + def begin_state(self, ctx, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [self._batch_size] * len(ctx) + + state_list = [self._modeling_layer.begin_state(b, dtype=self._precision, - ctx=c) for c in ctx] + ctx=c) for c, b in zip(ctx, + batch_sizes)] return state_list def hybrid_forward(self, F, x, state, *args): @@ -213,10 +222,14 @@ def __init__(self, batch_size, span_start_input_dim=100, units=None, nlayers=1, input_size=2 * span_start_input_dim) self._end_index_dense = nn.Dense(units=units, in_units=units * embedding_size) - def begin_state(self, ctx): - state_list = [self._end_index_lstm.begin_state(self._batch_size, + def begin_state(self, ctx, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [self._batch_size] * len(ctx) + + state_list = [self._end_index_lstm.begin_state(b, dtype=self._precision, - ctx=c) for c in ctx] + ctx=c) for c, b in zip(ctx, + batch_sizes)] return state_list def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-differ diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py new file mode 100644 index 0000000000..a1a561eaf0 --- /dev/null +++ b/scripts/question_answering/tokenizer.py @@ -0,0 +1,60 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import re + +from gluonnlp.data import SpacyTokenizer + + +class BiDAFTokenizer: + def __init__(self, base_tokenizer=SpacyTokenizer(), lower_case=False): + self._base_tokenizer = base_tokenizer + self._lower_case = lower_case + + def __call__(self, sample): + """ + + Parameters + ---------- + sample: str + The sentence to tokenize + + Returns + ------- + ret : list of strs + List of tokens + """ + tokens = [token.replace("''", '"').replace("``", '"') for token in + self._base_tokenizer(sample)] + + if self._lower_case: + tokens = [token.lower() for token in tokens] + + tokens = BiDAFTokenizer._process_tokens(tokens) + return tokens + + @staticmethod + def _process_tokens(temp_tokens): + tokens = [] + splitters = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", + "\u2019", "\u201D", "\u2018", "\u00B0") + + for token in temp_tokens: + tokens.extend(re.split("([{}])".format("".join(splitters)), token)) + + return tokens diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index b0000054ed..1e3401b52e 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import math + import multiprocessing import os from os.path import isfile @@ -35,12 +37,13 @@ from mxnet.gluon.loss import SoftmaxCrossEntropyLoss import gluonnlp as nlp -from gluonnlp.data import SQuAD +from gluonnlp.data import SQuAD, SpacyTokenizer from scripts.question_answering.data_processing import VocabProvider, SQuADTransform from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper +from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.utils import logging_config, get_args np.random.seed(100) @@ -85,7 +88,9 @@ def get_record_per_answer_span(processed_dataset, options): Returns ------- data : Tuple - A tuple of dataset and dataloader + A tuple of dataset and dataloader. Each item in dataset is: + (index, question_word_index, context_word_index, question_char_index, context_char_index, + answers) """ data_no_label = [] labels = [] @@ -95,6 +100,11 @@ def get_record_per_answer_span(processed_dataset, options): for r in processed_dataset: # creating a set out of answer_span will deduplicate them for answer_span in set(r[6]): + # if after all preprocessing the answer is not in the context anymore, + # the item is filtered out + if options.filter_long_context and (answer_span[0] > r[3].size or + answer_span[1] > r[3].size): + continue # need to remove question id before feeding the data to data loader # And I also replace index with global_index when unrolling answers data_no_label.append((global_index, r[2], r[3], r[4], r[5])) @@ -104,11 +114,12 @@ def get_record_per_answer_span(processed_dataset, options): loadable_data = ArrayDataset(data_no_label, labels) dataloader = DataLoader(loadable_data, batch_size=options.batch_size * len(get_context(options)), - # shuffle=True, + shuffle=True, last_batch='discard', num_workers=(multiprocessing.cpu_count() - len(get_context(options)) - 2)) + print("Total records for training: {}".format(len(labels))) return loadable_data, dataloader @@ -174,16 +185,18 @@ def run_training(net, dataloader, ctx, options): Training arguments """ - hyperparameters = {'learning_rate': options.lr} + hyperparameters = {'learning_rate': options.lr, 'clip_gradient': options.clip} if options.precision == 'float16' and options.use_multiprecision_in_optimizer: hyperparameters["multi_precision"] = True - trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device") + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", + update_on_kvstore=False) loss_function = SoftmaxCrossEntropyLoss() train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) + iteration = 1 print("Starting training...") for e in range(args.epochs): @@ -239,11 +252,25 @@ def run_training(net, dataloader, ctx, options): for l in losses: l.backward() - trainer.step(options.batch_size) + # if iteration == 1: + # for name, param in net.collect_params().items(): + # ema.add(name, param.data(CTX[0])) + + trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) + trainer.allreduce_grads() + # gradients = decay_gradients(net, ctx, options) + # gluon.utils.clip_global_norm(gradients, options.clip) + reset_embedding_gradients(net, ctx) + trainer.update(options.batch_size, ignore_stale_grad=True) + + # for name, param in net.collect_params().items(): + # ema(name, param.data(CTX[0])) for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) + iteration += 1 + mx.nd.waitall() avg_loss /= (i * len(ctx)) @@ -252,7 +279,7 @@ def run_training(net, dataloader, ctx, options): avg_loss_scalar = avg_loss.asscalar() epoch_time = time() - e_start - print("\tEPOCH {:2}: train loss {:4.2f} | batch {:4} | lr {:5.3f} | " + print("\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} | " "Time per epoch {:5.2f} seconds" .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, epoch_time)) @@ -262,6 +289,57 @@ def run_training(net, dataloader, ctx, options): print("Training time {:6.2f} seconds".format(time() - train_start)) +def get_learning_rate_per_iteration(iteration, options): + """Returns learning rate based on current iteration. Used to implement learning rate warm up + technique + + :param int iteration: Number of iteration + :param NameSpace options: Training options + :return float: learning rate + """ + return min(options.lr, options.lr * (math.log(iteration) / math.log(options.lr_warmup_steps))) + + +def decay_gradients(model, ctx, options): + """Apply gradient decay to all layers. And only to UNK and EOS embeddings of Glove layers + + :param BiDAFModel model: Model in training + :param ctx: Contexts + :param NameSpace options: Training options + :return: Array of gradients + """ + gradients = [] + + for c in ctx: + for name, parameter in model.collect_params().items(): + grad = parameter.grad(c) + + if is_fixed_embedding_layer(name): + grad[0:2] += options.weight_decay * parameter.data(c)[0:2] + else: + grad += options.weight_decay * parameter.data(c) + gradients.append(grad) + + return gradients + + +def reset_embedding_gradients(model, ctx): + """Gradients for glove layers of both question and context embeddings doesn't need to be + trainer. We train only UNK and EOS embeddings. + + :param BiDAFModel model: Model in training + :param ctx: Contexts of training + """ + + for c in ctx: + model.q_embedding._word_embedding.weight.grad(ctx=c)[2:] = 0 + model.ctx_embedding._word_embedding.weight.grad(ctx=c)[2:] = 0 + + +def is_fixed_embedding_layer(name): + return True if "predefined_embedding_layer" in name else False + + def save_model_parameters(net, epoch, options): """Save parameters of the trained model @@ -343,7 +421,8 @@ def load_transformed_dataset(path): word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) ctx = get_context(args) - evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, + dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.cast(args.precision) net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) @@ -362,24 +441,20 @@ def load_transformed_dataset(path): dataset = SQuAD(segment='dev') mapper = QuestionIdMapper(dataset) - transformed_dataset = load_transformed_dataset(args.preprocessed_val_dataset_path) \ - if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path) \ - else transform_dataset(dataset, vocab_provider, options=args) - if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path): transformed_dataset = load_transformed_dataset(args.preprocessed_val_dataset_path) else: transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) - val_dataset, val_dataloader = get_record_per_answer_span(transformed_dataset, args) word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) ctx = get_context(args) - evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, + dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.load_parameters(model_path, ctx=ctx) + net.hybridize(static_alloc=True) result = evaluator.evaluate_performance(net, ctx, args) print("Evaluation results on dev dataset: {}".format(result)) - diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 515b6f0bc7..6cb816a9d1 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -105,7 +105,13 @@ def get_args(): parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') + parser.add_argument('--lr_warmup_steps', type=int, default=1000, + help='Defines how many iterations to spend on warming up learning rate') parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') + parser.add_argument('--weight_decay', type=float, default=3e-7, + help='Weight decay') + parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, + help='Weight decay used in exponential moving average') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') parser.add_argument('--save_dir', type=str, default='out_dir', @@ -114,10 +120,12 @@ def get_args(): help='Coma-separated ids of the gpu to use. Empty means to use cpu.') parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], help='Use float16 or float32 precision') + parser.add_argument('--filter_long_context', type=bool, default='True', + help='Filter contexts if the answer is after ctx_max_len') parser.add_argument('--save_prediction_path', type=str, default='', help='Path to save predictions') parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, - help='When using float16, shall optimizer use multiprecision.') + help='When using float16, shall optimizer use multiprecision.') args = parser.parse_args() return args diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index e05d0527d1..18383550f0 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -17,21 +17,25 @@ # specific language governing permissions and limitations # under the License. import os + import pytest +import mxnet as mx from mxnet import init, nd, autograd from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset from mxnet.gluon.loss import SoftmaxCrossEntropyLoss -from mxnet.gluon.rnn import LSTM from types import SimpleNamespace import gluonnlp as nlp from gluonnlp.data import SQuAD from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.data_processing import SQuADTransform, VocabProvider +from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * +from scripts.question_answering.question_id_mapper import QuestionIdMapper from scripts.question_answering.similarity_function import DotProductSimilarity +from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.train_question_answering import get_record_per_answer_span question_max_length = 30 @@ -279,6 +283,119 @@ def test_bidaf_model(): nd.waitall() +def test_performance_evaluation(): + options = get_args(batch_size=5) + + train_dataset = SQuAD(segment='train') + vocab_provider = VocabProvider(train_dataset) + + dataset = SQuAD(segment='dev') + mapper = QuestionIdMapper(dataset) + + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word) + + # for performance reason, process only batch_size # of records + transformed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < options.batch_size]) + + word_vocab = vocab_provider.get_word_level_vocab() + word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) + char_vocab = vocab_provider.get_char_level_vocab() + model_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(int(options.epochs) - 1)) + + ctx = [mx.cpu()] + evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") + net.hybridize(static_alloc=True) + net.load_parameters(model_path, ctx=ctx) + + result = evaluator.evaluate_performance(net, ctx, options) + print("Evaluation results on dev dataset: {}".format(result)) + + +# def test_count_num_of_answer_index_greater_400(): +# counter_more_400 = 0 +# counter_less_400 = 0 +# train_dataset = SQuAD(segment='train') +# +# for item in train_dataset: +# for index in item[5]: +# if index >= 400: +# counter_more_400 += 1 +# else: +# counter_less_400 += 1 +# +# print("Less {}, More {}".format(counter_less_400, counter_more_400)) + + +def test_get_answer_spans_exact_match(): + tokenizer = BiDAFTokenizer() + + context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." + context_tokens = tokenizer(context) + + answer_start_index = 3 + answer = "Saint Bernadette Soubirous" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(1, 3)] + + +def test_get_answer_spans_partial_match(): + tokenizer = BiDAFTokenizer() + + context = "In addition, trucks will be allowed to enter India's capital only after 11 p.m., two hours later than the existing restriction" + context_tokens = tokenizer(context) + + answer_start_index = 72 + answer = "11 p.m" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(16, 17)] + + +def test_get_answer_spans_unicode(): + tokenizer = BiDAFTokenizer() + + context = "Back in Warsaw that year, Chopin heard Niccolò Paganini play" + context_tokens = tokenizer(context) + + answer_start_index = 39 + answer = "Niccolò Paganini" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(8, 9)] + + +def test_get_answer_spans_after_comma(): + tokenizer = BiDAFTokenizer() + + context = "Chopin's successes as a composer and performer opened the door to western Europe for him, and on 2 November 1830, he set out," + context_tokens = tokenizer(context) + + answer_start_index = 108 + answer = "1830" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(23, 23)] + +def test_get_char_indices(): + context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." + tokenizer = BiDAFTokenizer() + context_tokens = tokenizer(context) + + result = SQuADTransform._get_char_indices(context, context_tokens) + assert len(result) == len(context_tokens) + def get_args(batch_size): options = SimpleNamespace() options.ctx_embedding_num_layers = 2 @@ -293,5 +410,7 @@ def get_args(batch_size): options.q_max_len = question_max_length options.word_max_len = max_chars_per_word options.precision = "float16" + options.epochs = 100 + options.save_dir = "output/" return options From 19c37d23ab81deca5008e735cc0f8c2510a71a7b Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 10 Oct 2018 13:50:11 -0700 Subject: [PATCH 18/43] EMA is added to code + loss function change --- .../exponential_moving_average.py | 64 +++++++++ .../question_answering/question_answering.py | 7 +- .../train_question_answering.py | 127 +++++++++++++----- scripts/question_answering/utils.py | 6 +- scripts/tests/test_question_answering.py | 84 ++++++++---- 5 files changed, 218 insertions(+), 70 deletions(-) create mode 100644 scripts/question_answering/exponential_moving_average.py diff --git a/scripts/question_answering/exponential_moving_average.py b/scripts/question_answering/exponential_moving_average.py new file mode 100644 index 0000000000..e03612f2a7 --- /dev/null +++ b/scripts/question_answering/exponential_moving_average.py @@ -0,0 +1,64 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable= +"""Exponential Moving Average""" +import mxnet as mx +from mxnet import gluon + + +class PolyakAveraging: + def __init__(self, params, decay): + self._params = params + self._decay = decay + + self._polyak_params_dict = gluon.ParameterDict() + + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name, shape=param.shape) + polyak_param.initialize(mx.init.Constant(self._param_data_to_cpu(param)), ctx=mx.cpu()) + + def update(self): + """ + Updates currently held saved parameters with current state of network. + + All calculations for this average occur on the cpu context. + """ + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name) + polyak_param.set_data( + (1 - self._decay) * self._param_data_to_cpu(param) + + self._decay * polyak_param.data(mx.cpu())) + + def get_params(self): + """ + :return: returns the averaged parameters + :rtype: gluon.ParameterDict + """ + return self._polyak_params_dict + + def _param_data_to_cpu(self, param): + """ + Returns a copy (on CPU context) of the data held in some context of given parameter. + + :param gluon.Parameter param: parameter's whose data needs to be copied. + :return: copy of data on CPU context. + :rtype: nd.NDArray + """ + return param.list_data()[0].copyto(mx.cpu()) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index ca59de0b15..7f2d319506 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -252,12 +252,13 @@ def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-di # end_index_softmax_output = end_index_dense_output.softmax(axis=1) # end_index = F.argmax(end_index_softmax_output, axis=1) + return start_index_dense_output, end_index_dense_output # producing output in shape 2 x batch_size x units - output = F.concat(F.expand_dims(start_index_dense_output, axis=0), - F.expand_dims(end_index_dense_output, axis=0), dim=0) + # output = F.concat(F.expand_dims(start_index_dense_output, axis=0), + # F.expand_dims(end_index_dense_output, axis=0), dim=0) # transposing it to batch_size x 2 x units - return F.transpose(output, axes=(1, 0, 2)) + # return F.transpose(output, axes=(1, 0, 2)) class BiDAFModel(HybridBlock): diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 1e3401b52e..bc27a8559e 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -20,6 +20,7 @@ import multiprocessing import os +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from os.path import isfile import logging @@ -34,12 +35,12 @@ from mxnet import gluon, init, autograd from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset -from mxnet.gluon.loss import SoftmaxCrossEntropyLoss import gluonnlp as nlp -from gluonnlp.data import SQuAD, SpacyTokenizer +from gluonnlp.data import SQuAD from scripts.question_answering.data_processing import VocabProvider, SQuADTransform +from scripts.question_answering.exponential_moving_average import PolyakAveraging from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper @@ -159,7 +160,8 @@ def get_context(options): ctx = [] if options.gpu is None: - ctx.append(mx.cpu()) + ctx.append(mx.cpu(0)) + ctx.append(mx.cpu(1)) print('Use CPU') else: indices = options.gpu.split(',') @@ -185,14 +187,14 @@ def run_training(net, dataloader, ctx, options): Training arguments """ - hyperparameters = {'learning_rate': options.lr, 'clip_gradient': options.clip} + hyperparameters = {'learning_rate': options.lr} if options.precision == 'float16' and options.use_multiprecision_in_optimizer: hyperparameters["multi_precision"] = True - trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", - update_on_kvstore=False) + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device") loss_function = SoftmaxCrossEntropyLoss() + ema = None train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) @@ -241,30 +243,61 @@ def run_training(net, dataloader, ctx, options): m_layer_begin_state_list, o_layer_begin_state_list): with autograd.record(): - o = net(qw, cw, qc, cc, - ctx_embedding_begin_state, - q_embedding_begin_state, - m_layer_begin_state, - o_layer_begin_state) - loss = loss_function(o, l) + begin, end = net(qw, cw, qc, cc, + ctx_embedding_begin_state, + q_embedding_begin_state, + m_layer_begin_state, + o_layer_begin_state) + begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) + loss = loss_function(begin, begin_end[0]) + loss_function(end, begin_end[1]) losses.append(loss) - for l in losses: - l.backward() + for loss in losses: + loss.backward() - # if iteration == 1: - # for name, param in net.collect_params().items(): - # ema.add(name, param.data(CTX[0])) + if iteration == 1 and args.use_exponential_moving_average: + ema = PolyakAveraging(net.collect_params(), + args.exponential_moving_average_weight_decay) - trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) - trainer.allreduce_grads() - # gradients = decay_gradients(net, ctx, options) - # gluon.utils.clip_global_norm(gradients, options.clip) - reset_embedding_gradients(net, ctx) - trainer.update(options.batch_size, ignore_stale_grad=True) + # for loss in losses: + # loss.asnumpy() + # + # print("Iteration {}.1".format(iteration)) - # for name, param in net.collect_params().items(): - # ema(name, param.data(CTX[0])) + trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) + trainer.step(options.batch_size) + # for loss in losses: + # loss.asnumpy() + # + # print("Iteration {}.1".format(iteration)) + # trainer.allreduce_grads() + # for loss in losses: + # loss.asnumpy() + # + # print("Iteration {}.2".format(iteration)) + # + # gradients = decay_gradients(net, ctx[0], options) + # + # for gradient in gradients: + # gradient.asnumpy() + # + # print("Iteration {}.3".format(iteration)) + # + # gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) + # reset_embedding_gradients(net, ctx[0]) + # + # for parameter in net.collect_params(): + # grads = parameter.list_grad() + # source = grads[0] + # destination = grads[1:] + # + # for dest in destination: + # source.copyto(dest) + # + # trainer.update(options.batch_size, ignore_stale_grad=True) + + if ema is not None: + ema.update() for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) @@ -285,6 +318,7 @@ def run_training(net, dataloader, ctx, options): epoch_time)) save_model_parameters(net, e, options) + save_ema_parameters(ema, e, options) print("Training time {:6.2f} seconds".format(time() - train_start)) @@ -310,15 +344,14 @@ def decay_gradients(model, ctx, options): """ gradients = [] - for c in ctx: - for name, parameter in model.collect_params().items(): - grad = parameter.grad(c) + for name, parameter in model.collect_params().items(): + grad = parameter.grad(ctx) - if is_fixed_embedding_layer(name): - grad[0:2] += options.weight_decay * parameter.data(c)[0:2] - else: - grad += options.weight_decay * parameter.data(c) - gradients.append(grad) + # if is_fixed_embedding_layer(name): + # grad[0:2] += options.weight_decay * parameter.data(ctx)[0:2] + # else: + # grad += options.weight_decay * parameter.data(ctx) + gradients.append(grad) return gradients @@ -330,10 +363,8 @@ def reset_embedding_gradients(model, ctx): :param BiDAFModel model: Model in training :param ctx: Contexts of training """ - - for c in ctx: - model.q_embedding._word_embedding.weight.grad(ctx=c)[2:] = 0 - model.ctx_embedding._word_embedding.weight.grad(ctx=c)[2:] = 0 + model.q_embedding._word_embedding.weight.grad(ctx=ctx)[2:] = 0 + model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[2:] = 0 def is_fixed_embedding_layer(name): @@ -360,6 +391,28 @@ def save_model_parameters(net, epoch, options): net.save_parameters(save_path) +def save_ema_parameters(ema, epoch, options): + """Save exponentially averaged parameters of the trained model + + Parameters + ---------- + ema : `PolyakAveraging` + Model with trained parameters + epoch : `int` + Number of epoch + options : `Namespace` + Saving arguments + """ + if ema is None: + return + + if not os.path.exists(options.save_dir): + os.mkdir(options.save_dir) + + save_path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format(epoch)) + ema.get_params().save(save_path) + + def save_transformed_dataset(dataset, path): """Save processed dataset into a file. diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 6cb816a9d1..0780efec0b 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -110,8 +110,6 @@ def get_args(): parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') parser.add_argument('--weight_decay', type=float, default=3e-7, help='Weight decay') - parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, - help='Weight decay used in exponential moving average') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') parser.add_argument('--save_dir', type=str, default='out_dir', @@ -126,6 +124,10 @@ def get_args(): help='Path to save predictions') parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, help='When using float16, shall optimizer use multiprecision.') + parser.add_argument('--use_exponential_moving_average', type=bool, default=False, + help='Should averaged copy of parameters been stored.') + parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, + help='Weight decay used in exponential moving average') args = parser.parse_args() return args diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 18383550f0..7e5ec1b79b 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -21,7 +21,7 @@ import pytest import mxnet as mx -from mxnet import init, nd, autograd +from mxnet import init, nd, autograd, gluon from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset from mxnet.gluon.loss import SoftmaxCrossEntropyLoss @@ -33,6 +33,7 @@ from scripts.question_answering.data_processing import SQuADTransform, VocabProvider from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * +from scripts.question_answering.question_answering import TextIndicesLoss from scripts.question_answering.question_id_mapper import QuestionIdMapper from scripts.question_answering.similarity_function import DotProductSimilarity from scripts.question_answering.tokenizer import BiDAFTokenizer @@ -228,6 +229,7 @@ def test_output_layer(): def test_bidaf_model(): options = get_args(batch_size=5) + ctx = [mx.cpu(0), mx.cpu(1)] dataset = SQuAD(segment='dev', root='tests/data/squad') vocab_provider = VocabProvider(dataset) @@ -236,7 +238,7 @@ def test_bidaf_model(): # for performance reason, process only batch_size # of records processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) - if i < options.batch_size]) + if i < options.batch_size * len(ctx)]) # need to remove question id before feeding the data to data loader loadable_data, dataloader = get_record_per_answer_span(processed_dataset, options) @@ -245,38 +247,60 @@ def test_bidaf_model(): word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) char_vocab = vocab_provider.get_char_level_vocab() - model = BiDAFModel(word_vocab=word_vocab, - char_vocab=char_vocab, - options=options) + net = BiDAFModel(word_vocab=word_vocab, + char_vocab=char_vocab, + options=options) - model.cast("float16") - model.initialize(init.Xavier(magnitude=2.24)) - model.hybridize(static_alloc=True) + #net.cast("float16") + net.initialize(init.Xavier(magnitude=2.24)) + #net.hybridize(static_alloc=True) - ctx_embedding_begin_state = model.ctx_embedding.begin_state() - q_embedding_begin_state = model.q_embedding.begin_state() - m_layer_begin_state = model.modeling_layer.begin_state() - o_layer_begin_state = model.output_layer.begin_state() + ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) + o_layer_begin_state_list = net.output_layer.begin_state(ctx) - loss_function = SoftmaxCrossEntropyLoss() - trainer = Trainer(model.collect_params(), "adadelta", {"learning_rate": 0.5, - "multi_precision": True}) + loss_function = TextIndicesLoss() + trainer = Trainer(net.collect_params(), "adadelta", {"learning_rate": 0.5, + "multi_precision": True}) for i, (data, label) in enumerate(dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - q_words = q_words.astype("float16") - ctx_words = ctx_words.astype("float16") - q_chars = q_chars.astype("float16") - ctx_chars = ctx_chars.astype("float16") - label = label.astype("float16") + # q_words = q_words.astype("float16") + # ctx_words = ctx_words.astype("float16") + # q_chars = q_chars.astype("float16") + # ctx_chars = ctx_chars.astype("float16") + # label = label.astype("float16") + + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + label = gluon.utils.split_and_load(label, ctx, even_split=False) + + losses = [] + + for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ + q_embedding_begin_state, m_layer_begin_state, \ + o_layer_begin_state in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label, + ctx_embedding_begin_state_list, + q_embedding_begin_state_list, + m_layer_begin_state_list, + o_layer_begin_state_list): + with autograd.record(): + begin, end = net(qw, cw, qc, cc, + ctx_embedding_begin_state, + q_embedding_begin_state, + m_layer_begin_state, + o_layer_begin_state) + loss = loss_function(begin, end, l) + losses.append(loss) + + for loss in losses: + loss.backward() - with autograd.record(): - out = model(record_index, q_words, ctx_words, q_chars, ctx_chars, - ctx_embedding_begin_state, q_embedding_begin_state, - m_layer_begin_state, o_layer_begin_state) - loss = loss_function(out, label) - - loss.backward() trainer.step(options.batch_size) break @@ -388,6 +412,7 @@ def test_get_answer_spans_after_comma(): assert result == [(23, 23)] + def test_get_char_indices(): context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." tokenizer = BiDAFTokenizer() @@ -396,8 +421,10 @@ def test_get_char_indices(): result = SQuADTransform._get_char_indices(context, context_tokens) assert len(result) == len(context_tokens) + def get_args(batch_size): options = SimpleNamespace() + options.gpu = None options.ctx_embedding_num_layers = 2 options.embedding_size = 100 options.dropout = 0.2 @@ -409,8 +436,9 @@ def get_args(batch_size): options.ctx_max_len = context_max_length options.q_max_len = question_max_length options.word_max_len = max_chars_per_word - options.precision = "float16" + options.precision = "float32" options.epochs = 100 options.save_dir = "output/" + options.filter_long_context = False return options From 60a4374de88d82a16929793dfdf805ca4f8d58e3 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 10 Oct 2018 14:20:09 -0700 Subject: [PATCH 19/43] EMA can be used for prediction --- .../question_answering/performance_evaluator.py | 17 ++++++++--------- .../train_question_answering.py | 13 ++++++++++--- scripts/question_answering/utils.py | 3 ++- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 5e62a2ece2..5fab5f14e7 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -87,17 +87,16 @@ def evaluate_performance(self, net, ctx, options): q_embedding_begin_state_list, m_layer_begin_state_list, o_layer_begin_state_list): - out = net(qw, cw, qc, cc, - ctx_embedding_begin_state, - q_embedding_begin_state, - m_layer_begin_state, - o_layer_begin_state) - outs.append(out) + begin, end = net(qw, cw, qc, cc, + ctx_embedding_begin_state, + q_embedding_begin_state, + m_layer_begin_state, + o_layer_begin_state) + outs.append((begin, end)) for out in outs: - out_per_index = out.transpose(axes=(1, 0, 2)) - start_indices = PerformanceEvaluator._get_index(out_per_index[0]) - end_indices = PerformanceEvaluator._get_index(out_per_index[1]) + start_indices = PerformanceEvaluator._get_index(out[0]) + end_indices = PerformanceEvaluator._get_index(out[1]) # iterate over batches for idx, start, end in zip(data[0], start_indices, end_indices): diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index bc27a8559e..3dd5d0f9c3 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -486,8 +486,6 @@ def load_transformed_dataset(path): if args.evaluate: print("Running in evaluation mode") # we use training dataset to build vocabs - model_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(int(args.epochs) - 1)) - train_dataset = SQuAD(segment='train') vocab_provider = VocabProvider(train_dataset) @@ -506,7 +504,16 @@ def load_transformed_dataset(path): evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") - net.load_parameters(model_path, ctx=ctx) + + if args.use_exponential_moving_average: + params_path = os.path.join(args.save_dir, + 'ema_epoch{:d}.params'.format(int(args.epochs) - 1)) + net.collect_params().load(params_path, ctx=ctx) + else: + params_path = os.path.join(args.save_dir, + 'epoch{:d}.params'.format(int(args.epochs) - 1)) + net.load_parameters(params_path, ctx=ctx) + net.hybridize(static_alloc=True) result = evaluator.evaluate_performance(net, ctx, args) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 0780efec0b..f8d3d5f9ea 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -125,7 +125,8 @@ def get_args(): parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, help='When using float16, shall optimizer use multiprecision.') parser.add_argument('--use_exponential_moving_average', type=bool, default=False, - help='Should averaged copy of parameters been stored.') + help='Should averaged copy of parameters been stored and used ' + 'during evaluation.') parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, help='Weight decay used in exponential moving average') From ddaac0616de49195a570552f306ab9d9f81e8f43 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 10 Oct 2018 14:35:20 -0700 Subject: [PATCH 20/43] Caching of vocabs is added --- scripts/question_answering/data_processing.py | 27 ++++++++++++++++--- .../train_question_answering.py | 14 +++++----- scripts/question_answering/utils.py | 4 +++ 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 52c600d993..cd49af6c9b 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,6 +19,10 @@ # pylint: disable= """SQuAD data preprocessing.""" +import pickle + +from os.path import isfile + from gluonnlp.data import SpacyTokenizer from scripts.question_answering.tokenizer import BiDAFTokenizer @@ -209,8 +213,9 @@ def _pad_to_max_word_length(item, max_length): class VocabProvider(object): """Provides word level and character level vocabularies """ - def __init__(self, dataset, tokenizer=BiDAFTokenizer()): + def __init__(self, dataset, options, tokenizer=BiDAFTokenizer()): self._dataset = dataset + self._options = options self._tokenizer = tokenizer def get_tokenizer(self): @@ -225,7 +230,15 @@ def get_char_level_vocab(self): Vocab Character level vocabulary """ - return VocabProvider._create_squad_vocab(iter, self._dataset) + if self._options.char_vocab_path and isfile(self._options.char_vocab_path): + return pickle.load(open(self._options.char_vocab_path, "rb")) + + char_level_vocab = VocabProvider._create_squad_vocab(iter, self._dataset) + + if self._options.char_vocab_path: + pickle.dump(char_level_vocab, open(self._options.char_vocab_path, "wb")) + + return char_level_vocab def get_word_level_vocab(self): """Provides word level vocabulary @@ -236,7 +249,15 @@ def get_word_level_vocab(self): Word level vocabulary """ - return VocabProvider._create_squad_vocab(self._tokenizer, self._dataset) + if self._options.word_vocab_path and isfile(self._options.word_vocab_path): + return pickle.load(open(self._options.word_vocab_path, "rb")) + + word_level_vocab = VocabProvider._create_squad_vocab(self._tokenizer, self._dataset) + + if self._options.word_vocab_path: + pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) + + return word_level_vocab @staticmethod def _create_squad_vocab(tokenization_fn, dataset): diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 3dd5d0f9c3..35a4a7e542 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -347,10 +347,10 @@ def decay_gradients(model, ctx, options): for name, parameter in model.collect_params().items(): grad = parameter.grad(ctx) - # if is_fixed_embedding_layer(name): - # grad[0:2] += options.weight_decay * parameter.data(ctx)[0:2] - # else: - # grad += options.weight_decay * parameter.data(ctx) + if is_fixed_embedding_layer(name): + grad[0:2] += options.weight_decay * parameter.data(ctx)[0:2] + else: + grad += options.weight_decay * parameter.data(ctx) gradients.append(grad) return gradients @@ -452,7 +452,7 @@ def load_transformed_dataset(path): print("Running in preprocessing mode") dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, args) transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) exit(0) @@ -461,7 +461,7 @@ def load_transformed_dataset(path): print("Running in training mode") dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, args) mapper = QuestionIdMapper(dataset) if args.preprocessed_dataset_path and isfile(args.preprocessed_dataset_path): @@ -487,7 +487,7 @@ def load_transformed_dataset(path): print("Running in evaluation mode") # we use training dataset to build vocabs train_dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(train_dataset) + vocab_provider = VocabProvider(train_dataset, args) dataset = SQuAD(segment='dev') mapper = QuestionIdMapper(dataset) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index f8d3d5f9ea..8201bb2c92 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -114,6 +114,10 @@ def get_args(): help='report interval') parser.add_argument('--save_dir', type=str, default='out_dir', help='directory path to save the final model and training log') + parser.add_argument('--word_vocab_path', type=str, default=None, + help='Path to preprocessed word-level vocabulary') + parser.add_argument('--char_vocab_path', type=str, default=None, + help='Path to preprocessed character-level vocabulary') parser.add_argument('--gpu', type=str, default=None, help='Coma-separated ids of the gpu to use. Empty means to use cpu.') parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], From 8379a2c03379148ac3d2a6977469d3de816ddc5c Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 10 Oct 2018 15:50:58 -0700 Subject: [PATCH 21/43] Making utils function support FP16 --- scripts/question_answering/bidaf.py | 23 ++++++++++++++++--- .../question_answering/question_answering.py | 3 ++- scripts/question_answering/utils.py | 14 ++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index 0f32e1d165..8ee7884599 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -18,6 +18,7 @@ # under the License. from mxnet import gluon +import numpy as np from .attention_flow import AttentionFlow from .utils import last_dim_softmax, weighted_sum, replace_masked_values, masked_softmax @@ -36,6 +37,7 @@ def __init__(self, passage_length, question_length, encoding_dim, + precision, **kwargs): super(BidirectionalAttentionFlow, self).__init__(**kwargs) @@ -43,10 +45,23 @@ def __init__(self, self._passage_length = passage_length self._question_length = question_length self._encoding_dim = encoding_dim + self._precision = precision self._matrix_attention = AttentionFlow(attention_similarity_function, batch_size, passage_length, question_length, encoding_dim) + def _get_big_negative_value(self): + if self._precision == 'float16': + return np.finfo(np.float16).min + else: + return np.finfo(np.float32).min + + def _get_small_positive_value(self): + if self._precision == 'float16': + return np.finfo(np.float16).eps + else: + return np.finfo(np.float32).eps + def hybrid_forward(self, F, encoded_passage, encoded_question, question_mask, passage_mask): # pylint: disable=arguments-differ """ @@ -63,7 +78,8 @@ def hybrid_forward(self, F, encoded_passage, encoded_question, question_mask, pa passage_question_similarity, question_mask, passage_question_similarity_shape, - question_mask_shape) + question_mask_shape, + epsilon=self._get_small_positive_value()) # Shape: (batch_size, passage_length, encoding_dim) encoded_question_shape = (self._batch_size, self._question_length, self._encoding_dim) passage_question_attention_shape = (self._batch_size, self._passage_length, @@ -78,13 +94,14 @@ def hybrid_forward(self, F, encoded_passage, encoded_question, question_mask, pa replace_masked_values(F, passage_question_similarity, question_mask.expand_dims(1), - -1e7) + replace_with=self._get_big_negative_value()) # Shape: (batch_size, passage_length) question_passage_similarity = masked_similarity.max(axis=-1) # Shape: (batch_size, passage_length) - question_passage_attention = masked_softmax(F, question_passage_similarity, passage_mask) + question_passage_attention = masked_softmax(F, question_passage_similarity, passage_mask, + epsilon=self._get_small_positive_value()) # Shape: (batch_size, encoding_dim) encoded_passage_shape = (self._batch_size, self._passage_length, self._encoding_dim) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 7f2d319506..798f853182 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -293,7 +293,8 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.batch_size, options.ctx_max_len, options.q_max_len, - 2 * options.embedding_size) + 2 * options.embedding_size, + options.precision) self.modeling_layer = BiDAFModelingLayer(options.batch_size, input_dim=options.embedding_size, nlayers=options.modeling_num_layers, diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 8201bb2c92..37de235b43 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -218,7 +218,7 @@ def combine_tensors(combination, tensors): return nd.concat(to_concatenate, dim=-1) -def masked_softmax(F, vector, mask): +def masked_softmax(F, vector, mask, epsilon): """ ``nd.softmax(vector)`` does not work if some elements of ``vector`` should be masked. This performs a softmax on just the non-masked portions of ``vector``. Passing @@ -236,7 +236,7 @@ def masked_softmax(F, vector, mask): # To limit numerical errors from large vector elements outside the mask, we zero these out. result = F.softmax(vector * mask, axis=-1) result = result * mask - result = F.broadcast_div(result, (result.sum(axis=1, keepdims=True) + 1e-13)) + result = F.broadcast_div(result, (result.sum(axis=1, keepdims=True) + epsilon)) return result @@ -273,7 +273,8 @@ def _last_dimension_applicator(F, tensor, mask, tensor_shape, - mask_shape): + mask_shape, + **kwargs): """ Takes a tensor with 3 or more dimensions and applies a function over the last dimension. We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) @@ -289,17 +290,18 @@ def _last_dimension_applicator(F, mask = mask.expand_dims(1) mask = mask.broadcast_to(shape=tensor_shape) mask = mask.reshape(shape=(-1, mask_shape[-1])) - reshaped_result = function_to_apply(F, reshaped_tensor, mask) + reshaped_result = function_to_apply(F, reshaped_tensor, mask, **kwargs) return reshaped_result.reshape(shape=tensor_shape) -def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape): +def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape, epsilon): """ Takes a tensor with 3 or more dimensions and does a masked softmax over the last dimension. We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) has shape ``(batch_size, sequence_length)``. """ - return _last_dimension_applicator(F, masked_softmax, tensor, mask, tensor_shape, mask_shape) + return _last_dimension_applicator(F, masked_softmax, tensor, mask, tensor_shape, mask_shape, + epsilon=epsilon) def last_dim_log_softmax(F, tensor, mask, tensor_shape, mask_shape): From e163925f7a1c65d9ebc20d025ad2878108c93d2b Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 12 Oct 2018 10:54:47 -0700 Subject: [PATCH 22/43] Dev set also present in vocab --- scripts/question_answering/data_processing.py | 39 +++++++++---- scripts/question_answering/tokenizer.py | 6 +- .../train_question_answering.py | 55 +++++++++++-------- 3 files changed, 62 insertions(+), 38 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index cd49af6c9b..6e8561e4a7 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -23,7 +23,7 @@ from os.path import isfile -from gluonnlp.data import SpacyTokenizer +import gluonnlp as nlp from scripts.question_answering.tokenizer import BiDAFTokenizer __all__ = ['SQuADTransform', 'VocabProvider', 'preprocess_dataset'] @@ -62,8 +62,9 @@ class SQuADTransform(object): """SQuADTransform class responsible for converting text data into NDArrays that can be later feed into DataProvider """ - def __init__(self, vocab_provider, question_max_length, context_max_length, max_chars_per_word): - self._word_vocab = vocab_provider.get_word_level_vocab() + def __init__(self, vocab_provider, question_max_length, context_max_length, + max_chars_per_word, embedding_size): + self._word_vocab = vocab_provider.get_word_level_vocab(embedding_size) self._char_vocab = vocab_provider.get_char_level_vocab() self._tokenizer = vocab_provider.get_tokenizer() @@ -213,8 +214,8 @@ def _pad_to_max_word_length(item, max_length): class VocabProvider(object): """Provides word level and character level vocabularies """ - def __init__(self, dataset, options, tokenizer=BiDAFTokenizer()): - self._dataset = dataset + def __init__(self, datasets, options, tokenizer=BiDAFTokenizer()): + self._datasets = datasets self._options = options self._tokenizer = tokenizer @@ -233,14 +234,18 @@ def get_char_level_vocab(self): if self._options.char_vocab_path and isfile(self._options.char_vocab_path): return pickle.load(open(self._options.char_vocab_path, "rb")) - char_level_vocab = VocabProvider._create_squad_vocab(iter, self._dataset) + all_chars = [] + for dataset in self._datasets: + all_chars.extend(VocabProvider._get_all_tokens(iter, dataset)) + + char_level_vocab = VocabProvider._create_squad_vocab(all_chars) if self._options.char_vocab_path: pickle.dump(char_level_vocab, open(self._options.char_vocab_path, "wb")) return char_level_vocab - def get_word_level_vocab(self): + def get_word_level_vocab(self, embedding_size): """Provides word level vocabulary Returns @@ -252,7 +257,13 @@ def get_word_level_vocab(self): if self._options.word_vocab_path and isfile(self._options.word_vocab_path): return pickle.load(open(self._options.word_vocab_path, "rb")) - word_level_vocab = VocabProvider._create_squad_vocab(self._tokenizer, self._dataset) + all_words = [] + for dataset in self._datasets: + all_words.extend(VocabProvider._get_all_tokens(self._tokenizer, dataset)) + + word_level_vocab = VocabProvider._create_squad_vocab(all_words) + word_level_vocab.set_embedding( + nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) if self._options.word_vocab_path: pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) @@ -260,13 +271,17 @@ def get_word_level_vocab(self): return word_level_vocab @staticmethod - def _create_squad_vocab(tokenization_fn, dataset): + def _create_squad_vocab(all_tokens): + counter = data.count_tokens(all_tokens) + vocab = Vocab(counter) + return vocab + + @staticmethod + def _get_all_tokens(tokenization_fn, dataset): all_tokens = [] for data_item in dataset: all_tokens.extend(tokenization_fn(data_item[2])) all_tokens.extend(tokenization_fn(data_item[3])) - counter = data.count_tokens(all_tokens) - vocab = Vocab(counter) - return vocab + return all_tokens diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index a1a561eaf0..91ae4a3fb5 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -39,13 +39,13 @@ def __call__(self, sample): ret : list of strs List of tokens """ - tokens = [token.replace("''", '"').replace("``", '"') for token in - self._base_tokenizer(sample)] + sample = sample.replace('\'\'', '\" ').replace(r'``', '\" ') + tokens = self._base_tokenizer(sample) if self._lower_case: tokens = [token.lower() for token in tokens] - tokens = BiDAFTokenizer._process_tokens(tokens) + # tokens = BiDAFTokenizer._process_tokens(tokens) return tokens @staticmethod diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 35a4a7e542..354b0b6daf 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -36,7 +36,6 @@ from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset -import gluonnlp as nlp from gluonnlp.data import SQuAD from scripts.question_answering.data_processing import VocabProvider, SQuADTransform @@ -70,7 +69,7 @@ def transform_dataset(dataset, vocab_provider, options): A tuple of dataset, QuestionIdMapper and original json data for evaluation """ transformer = SQuADTransform(vocab_provider, options.q_max_len, - options.ctx_max_len, options.word_max_len) + options.ctx_max_len, options.word_max_len, args.embedding_size) processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset)]) return processed_dataset @@ -139,11 +138,7 @@ def get_vocabs(vocab_provider, options): data : Tuple A tuple of word vocabulary and character vocabulary """ - word_vocab = vocab_provider.get_word_level_vocab() - - word_vocab.set_embedding( - nlp.embedding.create('glove', source='glove.6B.{}d'.format(options.embedding_size))) - + word_vocab = vocab_provider.get_word_level_vocab(options.embedding_size) char_vocab = vocab_provider.get_char_level_vocab() return word_vocab, char_vocab @@ -265,6 +260,12 @@ def run_training(net, dataloader, ctx, options): # print("Iteration {}.1".format(iteration)) trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) + + for c in ctx: + gradients = decay_gradients(net, c, options) + gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) + reset_embedding_gradients(net, c) + trainer.step(options.batch_size) # for loss in losses: # loss.asnumpy() @@ -335,7 +336,8 @@ def get_learning_rate_per_iteration(iteration, options): def decay_gradients(model, ctx, options): - """Apply gradient decay to all layers. And only to UNK and EOS embeddings of Glove layers + """Apply gradient decay to all layers. For predefined embedding layers, we train only + OOV token embeddings :param BiDAFModel model: Model in training :param ctx: Contexts @@ -347,8 +349,9 @@ def decay_gradients(model, ctx, options): for name, parameter in model.collect_params().items(): grad = parameter.grad(ctx) + # we train OOV token if is_fixed_embedding_layer(name): - grad[0:2] += options.weight_decay * parameter.data(ctx)[0:2] + grad[0] += options.weight_decay * parameter.data(ctx)[0] else: grad += options.weight_decay * parameter.data(ctx) gradients.append(grad) @@ -358,13 +361,13 @@ def decay_gradients(model, ctx, options): def reset_embedding_gradients(model, ctx): """Gradients for glove layers of both question and context embeddings doesn't need to be - trainer. We train only UNK and EOS embeddings. + trainer. We train only OOV token embedding. :param BiDAFModel model: Model in training :param ctx: Contexts of training """ - model.q_embedding._word_embedding.weight.grad(ctx=ctx)[2:] = 0 - model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[2:] = 0 + model.q_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 + model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 def is_fixed_embedding_layer(name): @@ -451,18 +454,26 @@ def load_transformed_dataset(path): print("Running in preprocessing mode") - dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(dataset, args) - transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) + # we use both datasets to create proper vocab + dataset_train = SQuAD(segment='train') + dataset_dev = SQuAD(segment='dev') + + vocab_provider = VocabProvider([dataset_train, dataset_dev], args) + transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=args) save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) + + if args.preprocessed_val_dataset_path: + transformed_dataset = transform_dataset(dataset_dev, vocab_provider, options=args) + save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) + exit(0) if args.train: print("Running in training mode") dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(dataset, args) - mapper = QuestionIdMapper(dataset) + dataset_val = SQuAD(segment='dev') + vocab_provider = VocabProvider([dataset, dataset_val], args) if args.preprocessed_dataset_path and isfile(args.preprocessed_dataset_path): transformed_dataset = load_transformed_dataset(args.preprocessed_dataset_path) @@ -474,22 +485,20 @@ def load_transformed_dataset(path): word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) ctx = get_context(args) - evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, - dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.cast(args.precision) net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) - net.hybridize(static_alloc=True) + net.hybridize() run_training(net, train_dataloader, ctx, options=args) if args.evaluate: print("Running in evaluation mode") - # we use training dataset to build vocabs - train_dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(train_dataset, args) + train_dataset = SQuAD(segment='train') dataset = SQuAD(segment='dev') + + vocab_provider = VocabProvider([train_dataset, dataset], args) mapper = QuestionIdMapper(dataset) if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path): From e72eb649c7186d9423cf92cb83c1256f99ae06ad Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 12 Oct 2018 14:47:18 -0700 Subject: [PATCH 23/43] GlobalGradClip seems to work on 4 gpu, 15 items --- scripts/question_answering/data_processing.py | 20 ----- .../question_answering/question_answering.py | 22 +++--- .../train_question_answering.py | 79 +++++++++---------- scripts/tests/test_question_answering.py | 3 +- 4 files changed, 51 insertions(+), 73 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 6e8561e4a7..e8ca65fe09 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -38,26 +38,6 @@ from gluonnlp.data.batchify import Pad -def preprocess_dataset(dataset, question_max_length, context_max_length, max_chars_per_word): - """Process SQuAD dataset by creating NDArray version of data - - :param Dataset dataset: SQuAD dataset - :param int question_max_length: Maximum length of question (padded or trimmed to that size) - :param int context_max_length: Maximum length of context (padded or trimmed to that size) - :param int max_chars_per_word: Maximum length of word (padded or trimmed to that size) - - Returns - ------- - SimpleDataset - Dataset of preprocessed records - """ - vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) - processed_dataset = SimpleDataset(dataset.trasform(transformer, lazy=False)) - return processed_dataset - - class SQuADTransform(object): """SQuADTransform class responsible for converting text data into NDArrays that can be later feed into DataProvider diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 798f853182..03da33c7fd 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -85,22 +85,17 @@ def begin_state(self, ctx, batch_sizes=None): def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively - word_level_data = F.transpose(w, axes=(1, 0)) - char_level_data = F.transpose(c, axes=(1, 0, 2)) - # Get word embeddings. Output is batch_size x seq_len x embedding size (100) - word_embedded = self._word_embedding(word_level_data) + word_embedded = self._word_embedding(w) # Get char level embedding in multiple steps: # Step 1. Embed into 8-dim vector - char_level_data = self._char_dense_embedding(char_level_data) + char_level_data = self._char_dense_embedding(c) # Step 2. Transpose to put seq_len first axis to later iterate over it # In that way we can get embedding per token of every batch - char_level_data = F.transpose(char_level_data, axes=(0, 2, 1, 3)) + char_level_data = F.transpose(char_level_data, axes=(1, 2, 0, 3)) - # Step 3. Iterate over tokens of each batch and apply convolutional encoder - # As a result of a single iteration, we get token embedding for every batch def convolute(token_of_all_batches, _): return self._char_conv_embedding(token_of_all_batches), [] @@ -114,12 +109,19 @@ def convolute(token_of_all_batches, _): # self._batch_size, # self._embedding_size)) - # Concat embeddings, making channels size = 200 + # Transpose to TNC, to join + word_embedded = F.transpose(word_embedded, axes=(1, 0, 2)) highway_input = F.concat(char_embedded, word_embedded, dim=2) + def highway(token_of_all_batches, _): + return self._highway_network(token_of_all_batches), [] + + highway_output, _ = F.contrib.foreach(highway, highway_input, []) + # Pass through highway, shape remains unchanged - highway_output = self._highway_network(highway_input) + # highway_output = self._highway_network(highway_input) + # Transpose to TNC - default for LSTM ce_output, ce_state = self._contextual_embedding(highway_output, contextual_embedding_state) return ce_output diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 354b0b6daf..b9ef4af28f 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -187,7 +187,7 @@ def run_training(net, dataloader, ctx, options): if options.precision == 'float16' and options.use_multiprecision_in_optimizer: hyperparameters["multi_precision"] = True - trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device") + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="local") loss_function = SoftmaxCrossEntropyLoss() ema = None @@ -254,48 +254,22 @@ def run_training(net, dataloader, ctx, options): ema = PolyakAveraging(net.collect_params(), args.exponential_moving_average_weight_decay) - # for loss in losses: - # loss.asnumpy() - # - # print("Iteration {}.1".format(iteration)) - trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) + trainer.allreduce_grads() + + gradients = decay_gradients(net, ctx[0], options) + gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) + reset_embedding_gradients(net, ctx[0]) + + for name, parameter in net.collect_params().items(): + grads = parameter.list_grad() + source = grads[0] + destination = grads[1:] - for c in ctx: - gradients = decay_gradients(net, c, options) - gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) - reset_embedding_gradients(net, c) - - trainer.step(options.batch_size) - # for loss in losses: - # loss.asnumpy() - # - # print("Iteration {}.1".format(iteration)) - # trainer.allreduce_grads() - # for loss in losses: - # loss.asnumpy() - # - # print("Iteration {}.2".format(iteration)) - # - # gradients = decay_gradients(net, ctx[0], options) - # - # for gradient in gradients: - # gradient.asnumpy() - # - # print("Iteration {}.3".format(iteration)) - # - # gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) - # reset_embedding_gradients(net, ctx[0]) - # - # for parameter in net.collect_params(): - # grads = parameter.list_grad() - # source = grads[0] - # destination = grads[1:] - # - # for dest in destination: - # source.copyto(dest) - # - # trainer.update(options.batch_size, ignore_stale_grad=True) + for dest in destination: + source.copyto(dest) + + trainer.update(len(ctx) * options.batch_size, ignore_stale_grad=True) if ema is not None: ema.update() @@ -320,6 +294,7 @@ def run_training(net, dataloader, ctx, options): save_model_parameters(net, e, options) save_ema_parameters(ema, e, options) + save_trainer_parameters(trainer, e, options) print("Training time {:6.2f} seconds".format(time() - train_start)) @@ -416,6 +391,28 @@ def save_ema_parameters(ema, epoch, options): ema.get_params().save(save_path) +def save_trainer_parameters(trainer, epoch, options): + """Save exponentially averaged parameters of the trained model + + Parameters + ---------- + trainer : `Trainer` + Trainer + epoch : `int` + Number of epoch + options : `Namespace` + Saving arguments + """ + if trainer is None: + return + + if not os.path.exists(options.save_dir): + os.mkdir(options.save_dir) + + save_path = os.path.join(options.save_dir, 'trainer_epoch{:d}.params'.format(epoch)) + trainer.save_states(save_path) + + def save_transformed_dataset(dataset, path): """Save processed dataset into a file. diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 7e5ec1b79b..32568d352c 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -33,7 +33,6 @@ from scripts.question_answering.data_processing import SQuADTransform, VocabProvider from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * -from scripts.question_answering.question_answering import TextIndicesLoss from scripts.question_answering.question_id_mapper import QuestionIdMapper from scripts.question_answering.similarity_function import DotProductSimilarity from scripts.question_answering.tokenizer import BiDAFTokenizer @@ -260,7 +259,7 @@ def test_bidaf_model(): m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) o_layer_begin_state_list = net.output_layer.begin_state(ctx) - loss_function = TextIndicesLoss() + loss_function = SoftmaxCrossEntropyLoss() trainer = Trainer(net.collect_params(), "adadelta", {"learning_rate": 0.5, "multi_precision": True}) From 5742beab9c7cf3847312d3eac8497929882b192a Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 16 Oct 2018 11:46:43 -0700 Subject: [PATCH 24/43] Bug fixes. EM=39.8, F1=51.965 after 23 epochs --- scripts/question_answering/data_processing.py | 4 +- ...etric.py => official_squad_eval_script.py} | 0 .../performance_evaluator.py | 5 +- .../question_answering/question_answering.py | 54 +++++++----- .../train_question_answering.py | 34 +++++-- scripts/question_answering/utils.py | 7 ++ scripts/tests/test_question_answering.py | 88 +++++++------------ 7 files changed, 106 insertions(+), 86 deletions(-) rename scripts/question_answering/{metric.py => official_squad_eval_script.py} (100%) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index e8ca65fe09..623ea3ad2a 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -26,13 +26,11 @@ import gluonnlp as nlp from scripts.question_answering.tokenizer import BiDAFTokenizer -__all__ = ['SQuADTransform', 'VocabProvider', 'preprocess_dataset'] +__all__ = ['SQuADTransform', 'VocabProvider'] -import re import numpy as np from mxnet import nd -from mxnet.gluon.data import SimpleDataset from gluonnlp import Vocab, data from gluonnlp.data.batchify import Pad diff --git a/scripts/question_answering/metric.py b/scripts/question_answering/official_squad_eval_script.py similarity index 100% rename from scripts/question_answering/metric.py rename to scripts/question_answering/official_squad_eval_script.py diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 5fab5f14e7..a9bcd75e82 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -19,10 +19,9 @@ """Performance evaluator - a proxy class used for plugging in official validation script""" import multiprocessing -import re from mxnet import nd, gluon from mxnet.gluon.data import DataLoader, ArrayDataset -from scripts.question_answering.metric import evaluate +from scripts.question_answering.official_squad_eval_script import evaluate class PerformanceEvaluator: @@ -73,7 +72,7 @@ def evaluate_performance(self, net, ctx, options): ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) o_layer_begin_state_list = net.output_layer.begin_state(ctx) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 03da33c7fd..90bce734d8 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -20,6 +20,7 @@ """BiDAF model blocks""" from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.similarity_function import DotProductSimilarity +from scripts.question_answering.utils import get_very_negative_number __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] @@ -207,22 +208,23 @@ class BiDAFOutputLayer(HybridBlock): params : `ParameterDict` or `None` Shared Parameters for this `Block`. """ - def __init__(self, batch_size, span_start_input_dim=100, units=None, nlayers=1, biflag=True, + def __init__(self, batch_size, span_start_input_dim=100, in_units=None, nlayers=1, biflag=True, dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) - units = 4 * span_start_input_dim if units is None else units - embedding_size = 1000 - + in_units = 10 * span_start_input_dim if in_units is None else in_units self._batch_size = batch_size self._precision = precision with self.name_scope(): - self._start_index_dense = nn.Dense(units=units, in_units=units * embedding_size) + self._dropout = nn.Dropout(rate=dropout) + self._start_index_dense = nn.Dense(units=1, in_units=in_units, + use_bias=False, flatten=False) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=2 * span_start_input_dim) - self._end_index_dense = nn.Dense(units=units, in_units=units * embedding_size) + self._end_index_dense = nn.Dense(units=1, in_units=in_units, + use_bias=False, flatten=False) def begin_state(self, ctx, batch_sizes=None): if batch_sizes is None: @@ -234,18 +236,28 @@ def begin_state(self, ctx, batch_sizes=None): batch_sizes)] return state_list - def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-differ + def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=arguments-differ # setting batch size as the first dimension start_index_input = F.transpose(F.concat(x, m, dim=2), axes=(1, 0, 2)) + start_index_input = self._dropout(start_index_input) + start_index_dense_output = self._start_index_dense(start_index_input) end_index_input_part, _ = self._end_index_lstm(m, state) end_index_input = F.transpose(F.concat(x, end_index_input_part, dim=2), - axes=(1, 0, 2)) + axes=(1, 0, 2)) + end_index_input = self._dropout(end_index_input) end_index_dense_output = self._end_index_dense(end_index_input) + start_index_dense_output = F.squeeze(start_index_dense_output) + start_index_dense_output_masked = start_index_dense_output + ((1 - mask) * + get_very_negative_number()) + + end_index_dense_output = F.squeeze(end_index_dense_output) + end_index_dense_output_masked = end_index_dense_output + ((1 - mask) * + get_very_negative_number()) # Don't need to apply softmax for training, but do need for prediction # Maybe should use autograd properties to check it # Will need to reuse it to actually make predictions @@ -254,7 +266,8 @@ def hybrid_forward(self, F, x, m, state, *args): # pylint: disable=arguments-di # end_index_softmax_output = end_index_dense_output.softmax(axis=1) # end_index = F.argmax(end_index_softmax_output, axis=1) - return start_index_dense_output, end_index_dense_output + return start_index_dense_output_masked, \ + end_index_dense_output_masked # producing output in shape 2 x batch_size x units # output = F.concat(F.expand_dims(start_index_dense_output, axis=0), # F.expand_dims(end_index_dense_output, axis=0), dim=0) @@ -280,15 +293,15 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.embedding_size, precision=options.precision, prefix="context_embedding") - self.q_embedding = BiDAFEmbedding(options.batch_size, - word_vocab, - char_vocab, - options.q_max_len, - options.ctx_embedding_num_layers, - options.highway_num_layers, - options.embedding_size, - precision=options.precision, - prefix="question_embedding") + # self.q_embedding = BiDAFEmbedding(options.batch_size, + # word_vocab, + # char_vocab, + # options.q_max_len, + # options.ctx_embedding_num_layers, + # options.highway_num_layers, + # options.embedding_size, + # precision=options.precision, + # prefix="question_embedding") # we multiple embedding_size by 2 because we use bidirectional embedding self.attention_layer = BidirectionalAttentionFlow(DotProductSimilarity(), @@ -315,7 +328,8 @@ def hybrid_forward(self, F, qw, cw, qc, cc, output_layer_states=None, *args): ctx_embedding_output = self.ctx_embedding(cw, cc, ctx_embedding_states) - q_embedding_output = self.q_embedding(qw, qc, q_embedding_states) + q_embedding_output = self.ctx_embedding(qw, qc, q_embedding_states) + # self.q_embedding(qw, qc, q_embedding_states) # attention layer expect batch_size x seq_length x channels ctx_embedding_output = F.transpose(ctx_embedding_output, axes=(1, 0, 2)) @@ -334,7 +348,7 @@ def hybrid_forward(self, F, qw, cw, qc, cc, # modeling layer expects seq_length x batch_size x channels modeling_layer_output = self.modeling_layer(attention_layer_output, modeling_layer_states) - output = self.output_layer(attention_layer_output, modeling_layer_output, + output = self.output_layer(attention_layer_output, modeling_layer_output, ctx_mask, output_layer_states) return output diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index b9ef4af28f..1111de5403 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -187,7 +187,8 @@ def run_training(net, dataloader, ctx, options): if options.precision == 'float16' and options.use_multiprecision_in_optimizer: hyperparameters["multi_precision"] = True - trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="local") + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", + update_on_kvstore=False) loss_function = SoftmaxCrossEntropyLoss() ema = None @@ -200,7 +201,7 @@ def run_training(net, dataloader, ctx, options): avg_loss *= 0 # Zero average loss of each epoch ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) # net.q_embedding.begin_state(ctx) m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) o_layer_begin_state_list = net.output_layer.begin_state(ctx) @@ -254,11 +255,19 @@ def run_training(net, dataloader, ctx, options): ema = PolyakAveraging(net.collect_params(), args.exponential_moving_average_weight_decay) + # in special mode we collect gradients and apply processing only after + # predefined number of grad_req_add_mode which acts like batch_size counter + if options.grad_req_add_mode > 0: + if not iteration % options.grad_req_add_mode != 0 and \ + iteration != len(dataloader): + iteration += 1 + continue + trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) trainer.allreduce_grads() gradients = decay_gradients(net, ctx[0], options) - gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=False) + gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=True) reset_embedding_gradients(net, ctx[0]) for name, parameter in net.collect_params().items(): @@ -269,7 +278,10 @@ def run_training(net, dataloader, ctx, options): for dest in destination: source.copyto(dest) - trainer.update(len(ctx) * options.batch_size, ignore_stale_grad=True) + scailing_coeff = len(ctx) * options.batch_size \ + if options.grad_req_add_mode == 0 else options.grad_req_add_mode + + trainer.update(scailing_coeff, ignore_stale_grad=True) if ema is not None: ema.update() @@ -341,7 +353,7 @@ def reset_embedding_gradients(model, ctx): :param BiDAFModel model: Model in training :param ctx: Contexts of training """ - model.q_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 + # model.q_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 @@ -485,7 +497,17 @@ def load_transformed_dataset(path): net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") net.cast(args.precision) net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) - net.hybridize() + net.hybridize(static_alloc=True) + + # total_params = sum( + # v.data().shape[0] if len(v.data().shape) == 1 else v.data().shape[1] + # if v.data().shape[0] == 0 else v.data().shape[0] if + # v.data().shape[1] == 0 else v.data().shape[0] * v.data().shape[1] + # for k, v in net.ctx_embedding._word_embedding.collect_params().items()) + # print('number of params: %d' % total_params) + + if args.grad_req_add_mode: + net.collect_params().setattr('grad_req', 'add') run_training(net, train_dataloader, ctx, options=args) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 37de235b43..a5625aab82 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -133,11 +133,18 @@ def get_args(): 'during evaluation.') parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, help='Weight decay used in exponential moving average') + parser.add_argument('--grad_req_add_mode', type=int, default=0, + help='Enable rolling gradient mode, where batch size is always 1 and ' + 'gradients are accumulated using single GPU') args = parser.parse_args() return args +def get_very_negative_number(): + return -1e30 + + def get_combined_dim(combination, tensor_dims): """ For use with :func:`combine_tensors`. This function computes the resultant dimension when diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 32568d352c..bd18ea8703 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -38,6 +38,7 @@ from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.train_question_answering import get_record_per_answer_span +batch_size = 5 question_max_length = 30 context_max_length = 400 max_chars_per_word = 16 @@ -47,9 +48,9 @@ @pytest.mark.serial def test_transform_to_nd_array(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) + context_max_length, max_chars_per_word, embedding_size) record = dataset[0] transformed_record = transformer(*record) @@ -60,9 +61,9 @@ def test_transform_to_nd_array(): @pytest.mark.serial def test_data_loader_able_to_read(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) + context_max_length, max_chars_per_word, embedding_size) record = dataset[0] processed_dataset = SimpleDataset([transformer(*record)]) @@ -83,19 +84,17 @@ def test_data_loader_able_to_read(): @pytest.mark.serial def test_load_vocabs(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, get_args(batch_size)) - assert vocab_provider.get_word_level_vocab() is not None + assert vocab_provider.get_word_level_vocab(embedding_size) is not None assert vocab_provider.get_char_level_vocab() is not None def test_bidaf_embedding(): - batch_size = 5 - dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) + context_max_length, max_chars_per_word, embedding_size) # for performance reason, process only batch_size # of records processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) @@ -104,7 +103,7 @@ def test_bidaf_embedding(): # need to remove question id before feeding the data to data loader loadable_data, dataloader = get_record_per_answer_span(processed_dataset, get_args(batch_size)) - word_vocab = vocab_provider.get_word_level_vocab() + word_vocab = vocab_provider.get_word_level_vocab(embedding_size) word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) char_vocab = vocab_provider.get_char_level_vocab() @@ -116,7 +115,7 @@ def test_bidaf_embedding(): embedding.cast("float16") embedding.initialize(init.Xavier(magnitude=2.24)) embedding.hybridize(static_alloc=True) - state = embedding.begin_state() + state = embedding.begin_state(mx.cpu()) trainer = Trainer(embedding.collect_params(), "sgd", {"learning_rate": 0.1, "multi_precision": True}) @@ -139,8 +138,6 @@ def test_bidaf_embedding(): def test_attention_layer(): - batch_size = 5 - ctx_fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 2 * embedding_size), dtype="float16") @@ -154,7 +151,8 @@ def test_attention_layer(): batch_size, context_max_length, question_max_length, - 2 * embedding_size) + 2 * embedding_size, + "float16") layer.cast("float16") layer.initialize() @@ -167,8 +165,6 @@ def test_attention_layer(): def test_modeling_layer(): - batch_size = 5 - # The modeling layer receive input in a shape of batch_size x T x 8d # T is the sequence length of context which is context_max_length # d is the size of embedding, which is embedding_size @@ -181,7 +177,7 @@ def test_modeling_layer(): layer.cast("float16") layer.initialize() layer.hybridize(static_alloc=True) - state = layer.begin_state() + state = layer.begin_state(mx.cpu()) trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": "0.1", "multi_precision": True}) @@ -195,8 +191,6 @@ def test_modeling_layer(): def test_output_layer(): - batch_size = 5 - # The output layer receive 2 inputs: the output of Modeling layer (context_max_length, # batch_size, 2 * embedding_size) and the output of Attention flow layer # (batch_size, context_max_length, 8 * embedding_size) @@ -213,7 +207,7 @@ def test_output_layer(): # The model doesn't need to know the hidden states, so I don't hold variables for the states layer.initialize() layer.hybridize(static_alloc=True) - state = layer.begin_state() + state = layer.begin_state(mx.cpu()) trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": 0.1, "multi_precision": True}) @@ -227,13 +221,13 @@ def test_output_layer(): def test_bidaf_model(): - options = get_args(batch_size=5) + options = get_args(batch_size) ctx = [mx.cpu(0), mx.cpu(1)] dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider(dataset, options) transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) + context_max_length, max_chars_per_word, embedding_size) # for performance reason, process only batch_size # of records processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) @@ -242,7 +236,7 @@ def test_bidaf_model(): # need to remove question id before feeding the data to data loader loadable_data, dataloader = get_record_per_answer_span(processed_dataset, options) - word_vocab = vocab_provider.get_word_level_vocab() + word_vocab = vocab_provider.get_word_level_vocab(embedding_size) word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) char_vocab = vocab_provider.get_char_level_vocab() @@ -250,12 +244,12 @@ def test_bidaf_model(): char_vocab=char_vocab, options=options) - #net.cast("float16") + net.cast("float16") net.initialize(init.Xavier(magnitude=2.24)) - #net.hybridize(static_alloc=True) + net.hybridize(static_alloc=True) ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.q_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) o_layer_begin_state_list = net.output_layer.begin_state(ctx) @@ -265,11 +259,11 @@ def test_bidaf_model(): for i, (data, label) in enumerate(dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - # q_words = q_words.astype("float16") - # ctx_words = ctx_words.astype("float16") - # q_chars = q_chars.astype("float16") - # ctx_chars = ctx_chars.astype("float16") - # label = label.astype("float16") + q_words = q_words.astype("float16") + ctx_words = ctx_words.astype("float16") + q_chars = q_chars.astype("float16") + ctx_chars = ctx_chars.astype("float16") + label = label.astype("float16") record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) @@ -307,28 +301,29 @@ def test_bidaf_model(): def test_performance_evaluation(): - options = get_args(batch_size=5) + options = get_args(batch_size) train_dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(train_dataset) + vocab_provider = VocabProvider(train_dataset, options) dataset = SQuAD(segment='dev') mapper = QuestionIdMapper(dataset) transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word) + context_max_length, max_chars_per_word, embedding_size) # for performance reason, process only batch_size # of records transformed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) if i < options.batch_size]) - word_vocab = vocab_provider.get_word_level_vocab() + word_vocab = vocab_provider.get_word_level_vocab(embedding_size) word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) char_vocab = vocab_provider.get_char_level_vocab() model_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(int(options.epochs) - 1)) ctx = [mx.cpu()] - evaluator = PerformanceEvaluator(transformed_dataset, dataset._read_data(), mapper) + evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, + dataset._read_data(), mapper) net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.hybridize(static_alloc=True) net.load_parameters(model_path, ctx=ctx) @@ -337,21 +332,6 @@ def test_performance_evaluation(): print("Evaluation results on dev dataset: {}".format(result)) -# def test_count_num_of_answer_index_greater_400(): -# counter_more_400 = 0 -# counter_less_400 = 0 -# train_dataset = SQuAD(segment='train') -# -# for item in train_dataset: -# for index in item[5]: -# if index >= 400: -# counter_more_400 += 1 -# else: -# counter_less_400 += 1 -# -# print("Less {}, More {}".format(counter_less_400, counter_more_400)) - - def test_get_answer_spans_exact_match(): tokenizer = BiDAFTokenizer() @@ -436,7 +416,7 @@ def get_args(batch_size): options.q_max_len = question_max_length options.word_max_len = max_chars_per_word options.precision = "float32" - options.epochs = 100 + options.epochs = 12 options.save_dir = "output/" options.filter_long_context = False From bc6b8a7ef0dab69fb3462598ca7931d7c2d0ca13 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 16 Oct 2018 13:53:27 -0700 Subject: [PATCH 25/43] Evaluation changed and resume training added --- .../performance_evaluator.py | 49 ++++++++++++------- .../train_question_answering.py | 20 +++++++- scripts/question_answering/utils.py | 1 + 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index a9bcd75e82..f5c340ddd0 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -50,6 +50,14 @@ def evaluate_performance(self, net, ctx, options): """ pred = {} + + # Allows to ensure that start index is always <= than end index + for c in ctx: + answer_mask_matrix = nd.zeros(shape=(1, options.ctx_max_len, options.ctx_max_len), ctx=c) + for idx in range(options.answer_max_len): + answer_mask_matrix += nd.eye(N=options.ctx_max_len, M=options.ctx_max_len, + k=idx, ctx=c) + eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) for r in self._evaluation_dataset]) eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, @@ -94,14 +102,15 @@ def evaluate_performance(self, net, ctx, options): outs.append((begin, end)) for out in outs: - start_indices = PerformanceEvaluator._get_index(out[0]) - end_indices = PerformanceEvaluator._get_index(out[1]) + start = out[0].softmax(axis=1) + end = out[1].softmax(axis=1) + start_end_span = PerformanceEvaluator._get_indices(start, end, answer_mask_matrix) # iterate over batches - for idx, start, end in zip(data[0], start_indices, end_indices): + for idx, start_end in zip(data[0], start_end_span): idx = int(idx.asscalar()) - start = int(start.asscalar()) - end = int(end.asscalar()) + start = int(start_end[0].asscalar()) + end = int(start_end[1].asscalar()) question_id = self._mapper.idx_to_question_id[idx] pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) @@ -150,19 +159,25 @@ def get_text_result(self, idx, answer_span): return text @staticmethod - def _get_index(prediction): - """Convert prediction to actual index in text + def _get_indices(begin, end, answer_mask_matrix): + r"""Select the begin and end position of answer span. + + At inference time, the predicted span (s, e) is chosen such that + begin_hat[s] * end_hat[e] is maximized and s ≤ e. Parameters ---------- - prediction : `NDArray` - Output of the network - - Returns - ------- - indices : `NDArray` - Indices of a word in context for whole batch + begin : NDArray + input tensor with shape `(batch_size, context_sequence_length)` + end : NDArray + input tensor with shape `(batch_size, context_sequence_length)` """ - indices_softmax_output = prediction.softmax(axis=1) - indices = nd.argmax(indices_softmax_output, axis=1) - return indices + begin_hat = begin.reshape(begin.shape + (1,)) + end_hat = end.reshape(end.shape + (1,)) + end_hat = end_hat.transpose(axes=(0, 2, 1)) + + result = nd.batch_dot(begin_hat, end_hat) * answer_mask_matrix.slice( + begin=(0, 0, 0), end=(1, begin_hat.shape[1], begin_hat.shape[1])) + yp1 = result.max(axis=2).argmax(axis=1, keepdims=True).astype('int32') + yp2 = result.max(axis=1).argmax(axis=1, keepdims=True).astype('int32') + return nd.concat(yp1, yp2, dim=-1) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 1111de5403..c68e8a85cb 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -189,6 +189,11 @@ def run_training(net, dataloader, ctx, options): trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", update_on_kvstore=False) + + if args.resume_training: + path = os.path.join(options.save_dir, 'trainer_epoch{:d}.params'.format(args.resume_training)) + trainer.load_states(path) + loss_function = SoftmaxCrossEntropyLoss() ema = None @@ -197,11 +202,11 @@ def run_training(net, dataloader, ctx, options): iteration = 1 print("Starting training...") - for e in range(args.epochs): + for e in range(0 if not args.resume_training else args.resume_training, args.epochs): avg_loss *= 0 # Zero average loss of each epoch ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) # net.q_embedding.begin_state(ctx) + q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) o_layer_begin_state_list = net.output_layer.begin_state(ctx) @@ -255,6 +260,11 @@ def run_training(net, dataloader, ctx, options): ema = PolyakAveraging(net.collect_params(), args.exponential_moving_average_weight_decay) + if args.resume_training: + path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format( + args.resume_training)) + ema.get_params().load(path) + # in special mode we collect gradients and apply processing only after # predefined number of grad_req_add_mode which acts like batch_size counter if options.grad_req_add_mode > 0: @@ -509,6 +519,12 @@ def load_transformed_dataset(path): if args.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') + if args.resume_training: + print("Resuming training from {} epoch".format(args.resume_training)) + params_path = os.path.join(args.save_dir, + 'epoch{:d}.params'.format(int(args.resume_training) - 1)) + net.load_parameters(params_path, ctx) + run_training(net, train_dataloader, ctx, options=args) if args.evaluate: diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index a5625aab82..bd454259ea 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -103,6 +103,7 @@ def get_args(): parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') + parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') parser.add_argument('--lr_warmup_steps', type=int, default=1000, From bf70d667ce64d6f09c8ceae3d1cf840301813289 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 17 Oct 2018 09:54:43 -0700 Subject: [PATCH 26/43] Training resuming added --- .../question_answering/train_question_answering.py | 13 +++++++++---- scripts/question_answering/utils.py | 2 ++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index c68e8a85cb..e923db012e 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -191,7 +191,8 @@ def run_training(net, dataloader, ctx, options): update_on_kvstore=False) if args.resume_training: - path = os.path.join(options.save_dir, 'trainer_epoch{:d}.params'.format(args.resume_training)) + path = os.path.join(options.save_dir, + 'trainer_epoch{:d}.params'.format(args.resume_training - 1)) trainer.load_states(path) loss_function = SoftmaxCrossEntropyLoss() @@ -232,7 +233,6 @@ def run_training(net, dataloader, ctx, options): label = gluon.utils.split_and_load(label, ctx, even_split=False) # Wait for completion of previous iteration to avoid unnecessary memory allocation - mx.nd.waitall() losses = [] for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ @@ -250,7 +250,8 @@ def run_training(net, dataloader, ctx, options): m_layer_begin_state, o_layer_begin_state) begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) - loss = loss_function(begin, begin_end[0]) + loss_function(end, begin_end[1]) + loss = loss_function(begin, begin_end[0]).mean() + \ + loss_function(end, begin_end[1]).mean() losses.append(loss) for loss in losses: @@ -262,7 +263,7 @@ def run_training(net, dataloader, ctx, options): if args.resume_training: path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format( - args.resume_training)) + args.resume_training - 1)) ema.get_params().load(path) # in special mode we collect gradients and apply processing only after @@ -329,6 +330,10 @@ def get_learning_rate_per_iteration(iteration, options): :param NameSpace options: Training options :return float: learning rate """ + + if options.resume_training: + return options.lr + return min(options.lr, options.lr * (math.log(iteration) / math.log(options.lr_warmup_steps))) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index bd454259ea..8a915cd74b 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -113,6 +113,8 @@ def get_args(): help='Weight decay') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') + parser.add_argument('--resume_training', type=int, default=0, + help='Resume training from this epoch number') parser.add_argument('--save_dir', type=str, default='out_dir', help='directory path to save the final model and training log') parser.add_argument('--word_vocab_path', type=str, default=None, From dde6539cc67e49ae1ec6b7cb1ef7f10117761d58 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 19 Oct 2018 15:22:30 -0700 Subject: [PATCH 27/43] Clean up code --- scripts/question_answering/data_processing.py | 47 ++++++++--- .../question_answering/question_answering.py | 60 ++++---------- .../question_answering/similarity_function.py | 3 +- scripts/question_answering/tokenizer.py | 21 +++-- .../train_question_answering.py | 79 ++++++++++--------- scripts/question_answering/utils.py | 21 +++-- scripts/tests/test_question_answering.py | 28 ++++++- 7 files changed, 149 insertions(+), 110 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 623ea3ad2a..873fb24fb9 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -58,16 +58,16 @@ def __call__(self, record_index, question_id, question, context, answer_list, Method converts text into numeric arrays based on Vocabulary. Answers are not processed, as they are not needed in input """ - question_tokens = self._tokenizer(question) - context_tokens = self._tokenizer(context) + question_tokens = self._tokenizer(question, lower_case=True) + context_tokens = self._tokenizer(context, lower_case=True) question_words = self._word_vocab[question_tokens[:self._question_max_length]] context_words = self._word_vocab[context_tokens[:self._context_max_length]] - question_chars = [self._char_vocab[list(iter(word))] + question_chars = [self._char_vocab[[character.lower() for character in word]] for word in question_tokens[:self._question_max_length]] - context_chars = [self._char_vocab[list(iter(word))] + context_chars = [self._char_vocab[[character.lower() for character in word]] for word in context_tokens[:self._context_max_length]] question_words_nd = self._pad_to_max_word_length(question_words, self._question_max_length) @@ -139,9 +139,10 @@ def _get_char_indices(text, text_tokens): """ char_indices_per_token = [] current_index = 0 + text_lowered = text.lower() for token in text_tokens: - current_index = text.find(token, current_index) + current_index = text_lowered.find(token, current_index) char_indices_per_token.append((current_index, current_index + len(token))) current_index += len(token) @@ -214,7 +215,7 @@ def get_char_level_vocab(self): all_chars = [] for dataset in self._datasets: - all_chars.extend(VocabProvider._get_all_tokens(iter, dataset)) + all_chars.extend(self._get_all_char_tokens(dataset)) char_level_vocab = VocabProvider._create_squad_vocab(all_chars) @@ -237,12 +238,25 @@ def get_word_level_vocab(self, embedding_size): all_words = [] for dataset in self._datasets: - all_words.extend(VocabProvider._get_all_tokens(self._tokenizer, dataset)) + all_words.extend(self._get_all_word_tokens(dataset)) word_level_vocab = VocabProvider._create_squad_vocab(all_words) word_level_vocab.set_embedding( nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) + count = 0 + count2 = 0 + for i in range(len(word_level_vocab)): + if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): + count += 1 + else: + if count2 < 50: + print(word_level_vocab.embedding.idx_to_token[i]) + count2 += 1 + + print("word_level_vocab {}, word_level_vocab.set_embedding {}".format( + len(word_level_vocab), count)) + if self._options.word_vocab_path: pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) @@ -254,12 +268,23 @@ def _create_squad_vocab(all_tokens): vocab = Vocab(counter) return vocab - @staticmethod - def _get_all_tokens(tokenization_fn, dataset): + def _get_all_word_tokens(self, dataset): all_tokens = [] for data_item in dataset: - all_tokens.extend(tokenization_fn(data_item[2])) - all_tokens.extend(tokenization_fn(data_item[3])) + all_tokens.extend(self._tokenizer(data_item[2], lower_case=True)) + all_tokens.extend(self._tokenizer(data_item[3], lower_case=True)) + + return all_tokens + + def _get_all_char_tokens(self, dataset): + all_tokens = [] + + for data_item in dataset: + for character in data_item[2]: + all_tokens.extend(character.lower()) + + for character in data_item[3]: + all_tokens.extend(character.lower()) return all_tokens diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 90bce734d8..b4d75088e4 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -51,7 +51,8 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, with self.name_scope(): self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), - output_dim=8) + output_dim=8, + weight_initializer=initializer.Uniform(0.001)) self._char_conv_embedding = ConvolutionalEncoder( embed_size=8, num_filters=(100,), @@ -70,10 +71,12 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, num_layers=contextual_embedding_nlayers, bidirectional=True, input_size=2 * embedding_size) - def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): - super(BiDAFEmbedding, self).initialize(init, ctx, verbose, force_reinit) + def init_embeddings(self, lock_gradients): self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) + if lock_gradients: + self._word_embedding.collect_params().setattr('grad_req', 'null') + def begin_state(self, ctx, batch_sizes=None): if batch_sizes is None: batch_sizes = [self._batch_size] * len(ctx) @@ -85,16 +88,10 @@ def begin_state(self, ctx, batch_sizes=None): return state_list def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): - # Changing shape from NTC to TNC as most MXNet blocks work with TNC format natively - # Get word embeddings. Output is batch_size x seq_len x embedding size (100) word_embedded = self._word_embedding(w) - - # Get char level embedding in multiple steps: - # Step 1. Embed into 8-dim vector char_level_data = self._char_dense_embedding(c) - # Step 2. Transpose to put seq_len first axis to later iterate over it - # In that way we can get embedding per token of every batch + # Transpose to put seq_len first axis to iterate over it char_level_data = F.transpose(char_level_data, axes=(1, 2, 0, 3)) def convolute(token_of_all_batches, _): @@ -102,25 +99,15 @@ def convolute(token_of_all_batches, _): char_embedded, _ = F.contrib.foreach(convolute, char_level_data, []) - # Step 4. Concat all tokens embeddings to create a single tensor. - # char_embedded = F.concat(*token_list, dim=0) - - # Step 5. Reshape tensor to match dimensions of embedded words - # char_embedded = char_embedded.reshape(shape=(self._max_seq_len, - # self._batch_size, - # self._embedding_size)) - - # Transpose to TNC, to join + # Transpose to TNC, to join with character embedding word_embedded = F.transpose(word_embedded, axes=(1, 0, 2)) highway_input = F.concat(char_embedded, word_embedded, dim=2) def highway(token_of_all_batches, _): return self._highway_network(token_of_all_batches), [] - highway_output, _ = F.contrib.foreach(highway, highway_input, []) - # Pass through highway, shape remains unchanged - # highway_output = self._highway_network(highway_input) + highway_output, _ = F.contrib.foreach(highway, highway_input, []) # Transpose to TNC - default for LSTM ce_output, ce_state = self._contextual_embedding(highway_output, @@ -258,22 +245,9 @@ def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=argume end_index_dense_output = F.squeeze(end_index_dense_output) end_index_dense_output_masked = end_index_dense_output + ((1 - mask) * get_very_negative_number()) - # Don't need to apply softmax for training, but do need for prediction - # Maybe should use autograd properties to check it - # Will need to reuse it to actually make predictions - # start_index_softmax_output = start_index_dense_output.softmax(axis=1) - # start_index = F.argmax(start_index_softmax_output, axis=1) - # end_index_softmax_output = end_index_dense_output.softmax(axis=1) - # end_index = F.argmax(end_index_softmax_output, axis=1) return start_index_dense_output_masked, \ end_index_dense_output_masked - # producing output in shape 2 x batch_size x units - # output = F.concat(F.expand_dims(start_index_dense_output, axis=0), - # F.expand_dims(end_index_dense_output, axis=0), dim=0) - - # transposing it to batch_size x 2 x units - # return F.transpose(output, axes=(1, 0, 2)) class BiDAFModel(HybridBlock): @@ -284,6 +258,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): self._options = options with self.name_scope(): + # contextual embedding layer self.ctx_embedding = BiDAFEmbedding(options.batch_size, word_vocab, char_vocab, @@ -293,15 +268,6 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.embedding_size, precision=options.precision, prefix="context_embedding") - # self.q_embedding = BiDAFEmbedding(options.batch_size, - # word_vocab, - # char_vocab, - # options.q_max_len, - # options.ctx_embedding_num_layers, - # options.highway_num_layers, - # options.embedding_size, - # precision=options.precision, - # prefix="question_embedding") # we multiple embedding_size by 2 because we use bidirectional embedding self.attention_layer = BidirectionalAttentionFlow(DotProductSimilarity(), @@ -321,6 +287,11 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): dropout=options.dropout, precision=options.precision) + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, + force_reinit=False): + super(BiDAFModel, self).initialize(init, ctx, verbose, force_reinit) + self.ctx_embedding.init_embeddings(not self._options.train_unk_token) + def hybrid_forward(self, F, qw, cw, qc, cc, ctx_embedding_states=None, q_embedding_states=None, @@ -329,7 +300,6 @@ def hybrid_forward(self, F, qw, cw, qc, cc, *args): ctx_embedding_output = self.ctx_embedding(cw, cc, ctx_embedding_states) q_embedding_output = self.ctx_embedding(qw, qc, q_embedding_states) - # self.q_embedding(qw, qc, q_embedding_states) # attention layer expect batch_size x seq_length x channels ctx_embedding_output = F.transpose(ctx_embedding_output, axes=(1, 0, 2)) diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index ec4900869b..0b52e24bf7 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -66,8 +66,7 @@ def hybrid_forward(self, F, array_1, array_2): result = (array_1 * array_2).sum(axis=-1) if self._scale_output: - # result *= F.sqrt(array_1.shape[-1]) - result *= F.contrib.div_sqrt_dim(array_1) + result *= F.sqrt(array_1.shape[-1]) return result diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index 91ae4a3fb5..49ca53f406 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -22,11 +22,10 @@ class BiDAFTokenizer: - def __init__(self, base_tokenizer=SpacyTokenizer(), lower_case=False): + def __init__(self, base_tokenizer=SpacyTokenizer()): self._base_tokenizer = base_tokenizer - self._lower_case = lower_case - def __call__(self, sample): + def __call__(self, sample, lower_case=False): """ Parameters @@ -39,13 +38,14 @@ def __call__(self, sample): ret : list of strs List of tokens """ - sample = sample.replace('\'\'', '\" ').replace(r'``', '\" ') + sample = sample.replace('\'\'', '\" ').replace(r'``', '\" ')\ + .replace(u'\u000A', ' ').replace(u'\u00A0', ' ') tokens = self._base_tokenizer(sample) + tokens = BiDAFTokenizer._process_tokens(tokens) - if self._lower_case: + if lower_case: tokens = [token.lower() for token in tokens] - # tokens = BiDAFTokenizer._process_tokens(tokens) return tokens @staticmethod @@ -58,3 +58,12 @@ def _process_tokens(temp_tokens): tokens.extend(re.split("([{}])".format("".join(splitters)), token)) return tokens + + @staticmethod + def _isascii(token): + try: + token.encode('ascii') + return True + + except UnicodeEncodeError: + return False diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index e923db012e..44cbb03fa1 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -187,6 +187,9 @@ def run_training(net, dataloader, ctx, options): if options.precision == 'float16' and options.use_multiprecision_in_optimizer: hyperparameters["multi_precision"] = True + if options.rho: + hyperparameters["rho"] = options.rho + trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", update_on_kvstore=False) @@ -232,7 +235,6 @@ def run_training(net, dataloader, ctx, options): ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) label = gluon.utils.split_and_load(label, ctx, even_split=False) - # Wait for completion of previous iteration to avoid unnecessary memory allocation losses = [] for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ @@ -250,8 +252,8 @@ def run_training(net, dataloader, ctx, options): m_layer_begin_state, o_layer_begin_state) begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) - loss = loss_function(begin, begin_end[0]).mean() + \ - loss_function(end, begin_end[1]).mean() + loss = loss_function(begin, begin_end[0]) + \ + loss_function(end, begin_end[1]) losses.append(loss) for loss in losses: @@ -274,25 +276,35 @@ def run_training(net, dataloader, ctx, options): iteration += 1 continue - trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) - trainer.allreduce_grads() + scailing_coeff = len(ctx) * options.batch_size \ + if options.grad_req_add_mode == 0 else options.grad_req_add_mode + + if options.lr_warmup_steps: + trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) - gradients = decay_gradients(net, ctx[0], options) - gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=True) - reset_embedding_gradients(net, ctx[0]) + if options.clip or options.train_unk_token: + trainer.allreduce_grads() + gradients = get_gradients(net, ctx[0], options) - for name, parameter in net.collect_params().items(): - grads = parameter.list_grad() - source = grads[0] - destination = grads[1:] + if options.clip: + gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=True) - for dest in destination: - source.copyto(dest) + if options.train_unk_token: + reset_embedding_gradients(net, ctx[0]) - scailing_coeff = len(ctx) * options.batch_size \ - if options.grad_req_add_mode == 0 else options.grad_req_add_mode + if len(ctx) > 1: + # in multi gpu mode we propagate new gradients to the rest of gpus + for name, parameter in net.collect_params().items(): + grads = parameter.list_grad() + source = grads[0] + destination = grads[1:] - trainer.update(scailing_coeff, ignore_stale_grad=True) + for dest in destination: + source.copyto(dest) + + trainer.update(scailing_coeff) + else: + trainer.step(scailing_coeff) if ema is not None: ema.update() @@ -337,9 +349,8 @@ def get_learning_rate_per_iteration(iteration, options): return min(options.lr, options.lr * (math.log(iteration) / math.log(options.lr_warmup_steps))) -def decay_gradients(model, ctx, options): - """Apply gradient decay to all layers. For predefined embedding layers, we train only - OOV token embeddings +def get_gradients(model, ctx, options): + """Get gradients and apply gradient decay to all layers if required. :param BiDAFModel model: Model in training :param ctx: Contexts @@ -349,26 +360,29 @@ def decay_gradients(model, ctx, options): gradients = [] for name, parameter in model.collect_params().items(): + if is_fixed_embedding_layer(name) and not options.train_unk_token: + continue + grad = parameter.grad(ctx) - # we train OOV token - if is_fixed_embedding_layer(name): - grad[0] += options.weight_decay * parameter.data(ctx)[0] - else: - grad += options.weight_decay * parameter.data(ctx) + if options.weight_decay: + if is_fixed_embedding_layer(name): + grad[0] += options.weight_decay * parameter.data(ctx)[0] + else: + grad += options.weight_decay * parameter.data(ctx) + gradients.append(grad) return gradients def reset_embedding_gradients(model, ctx): - """Gradients for glove layers of both question and context embeddings doesn't need to be - trainer. We train only OOV token embedding. + """Gradients for of embedding layer doesn't need to be trained. + We train only UNK token of embedding if required. :param BiDAFModel model: Model in training :param ctx: Contexts of training """ - # model.q_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 @@ -490,8 +504,6 @@ def load_transformed_dataset(path): transformed_dataset = transform_dataset(dataset_dev, vocab_provider, options=args) save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) - exit(0) - if args.train: print("Running in training mode") @@ -514,13 +526,6 @@ def load_transformed_dataset(path): net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) net.hybridize(static_alloc=True) - # total_params = sum( - # v.data().shape[0] if len(v.data().shape) == 1 else v.data().shape[1] - # if v.data().shape[0] == 0 else v.data().shape[0] if - # v.data().shape[1] == 0 else v.data().shape[0] * v.data().shape[1] - # for k, v in net.ctx_embedding._word_embedding.collect_params().items()) - # print('number of params: %d' % total_params) - if args.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 8a915cd74b..bd54d6ff1b 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -80,12 +80,13 @@ def get_args(): parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') parser.add_argument('--train', type=bool, default=False, help='Run training') - parser.add_argument('--evaluate', type=bool, default=False, help='Run evaluation on dev dataset') + parser.add_argument('--evaluate', type=bool, default=False, + help='Run evaluation on dev dataset') parser.add_argument('--preprocessed_dataset_path', type=str, default="preprocessed_dataset.p", help='Path to preprocessed dataset') parser.add_argument('--preprocessed_val_dataset_path', type=str, - default="preprocessed_val_dataset.p", help='Path to preprocessed ' - 'validation dataset') + default="preprocessed_val_dataset.p", + help='Path to preprocessed validation dataset') parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') parser.add_argument('--embedding_size', type=int, default=100, help='Dimension of the word embedding') @@ -99,17 +100,19 @@ def get_args(): help='Number of layers in Modeling layer of BiDAF') parser.add_argument('--output_num_layers', type=int, default=1, help='Number of layers in Output layer of BiDAF') - parser.add_argument('--batch_size', type=int, default=128, help='Batch size') + parser.add_argument('--batch_size', type=int, default=60, help='Batch size') parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') - parser.add_argument('--lr_warmup_steps', type=int, default=1000, + parser.add_argument('--rho', type=float, default=0.9, + help='Adadelta decay rate for both squared gradients and delta.') + parser.add_argument('--lr_warmup_steps', type=int, default=0, help='Defines how many iterations to spend on warming up learning rate') - parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') - parser.add_argument('--weight_decay', type=float, default=3e-7, + parser.add_argument('--clip', type=float, default=0, help='gradient clipping') + parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') @@ -123,9 +126,11 @@ def get_args(): help='Path to preprocessed character-level vocabulary') parser.add_argument('--gpu', type=str, default=None, help='Coma-separated ids of the gpu to use. Empty means to use cpu.') + parser.add_argument('--train_unk_token', type=bool, default=False, + help='Should train unknown token of embedding') parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], help='Use float16 or float32 precision') - parser.add_argument('--filter_long_context', type=bool, default='True', + parser.add_argument('--filter_long_context', type=bool, default=True, help='Filter contexts if the answer is after ctx_max_len') parser.add_argument('--save_prediction_path', type=str, default='', help='Path to save predictions') diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index bd18ea8703..e9dc1c5c32 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -22,7 +22,7 @@ import mxnet as mx from mxnet import init, nd, autograd, gluon -from mxnet.gluon import Trainer +from mxnet.gluon import Trainer, nn from mxnet.gluon.data import DataLoader, SimpleDataset from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from types import SimpleNamespace @@ -31,6 +31,7 @@ from gluonnlp.data import SQuAD from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.data_processing import SQuADTransform, VocabProvider +from scripts.question_answering.exponential_moving_average import PolyakAveraging from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper @@ -401,6 +402,31 @@ def test_get_char_indices(): assert len(result) == len(context_tokens) +def test_polyak_averaging(): + net = nn.HybridSequential() + net.add(nn.Dense(5), nn.Dense(3), nn.Dense(2)) + net.initialize(init.Xavier()) + # net.hybridize() + + ema = None + loss_fn = SoftmaxCrossEntropyLoss() + trainer = Trainer(net.collect_params(), "sgd", {"learning_rate": 0.5}) + + train_data = mx.random.uniform(-0.1, 0.1, shape=(5, 10)) + train_label = mx.nd.array([0, 1, 1, 0, 1]) + + for i in range(3): + with autograd.record(): + o = net(train_data) + loss = loss_fn(o, train_label) + + if i == 0: + ema = PolyakAveraging(net.collect_params(), decay=0.999) + + loss.backward() + trainer.step(5) + ema.update() + def get_args(batch_size): options = SimpleNamespace() options.gpu = None From a1fcdccba06463a4a3a06b4c8f24d01113f9e89c Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 19 Oct 2018 17:23:47 -0700 Subject: [PATCH 28/43] Parameters parsing fixes --- .../train_question_answering.py | 6 +++--- scripts/question_answering/utils.py | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 44cbb03fa1..4565684872 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -287,7 +287,7 @@ def run_training(net, dataloader, ctx, options): gradients = get_gradients(net, ctx[0], options) if options.clip: - gluon.utils.clip_global_norm(gradients, options.clip, check_isfinite=True) + gluon.utils.clip_global_norm(gradients, options.clip) if options.train_unk_token: reset_embedding_gradients(net, ctx[0]) @@ -377,8 +377,8 @@ def get_gradients(model, ctx, options): def reset_embedding_gradients(model, ctx): - """Gradients for of embedding layer doesn't need to be trained. - We train only UNK token of embedding if required. + """Gradients for embedding layer doesn't need to be trained. We train only UNK token of + embedding if required. :param BiDAFModel model: Model in training :param ctx: Contexts of training diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index bd54d6ff1b..e52e237c1a 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -78,9 +78,11 @@ def get_args(): """Get console arguments """ parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') - parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess dataset only') - parser.add_argument('--train', type=bool, default=False, help='Run training') - parser.add_argument('--evaluate', type=bool, default=False, + parser.add_argument('--preprocess', default=False, action='store_true', + help='Preprocess dataset') + parser.add_argument('--train', default=False, action='store_true', + help='Run training') + parser.add_argument('--evaluate', default=False, action='store_true', help='Run evaluation on dev dataset') parser.add_argument('--preprocessed_dataset_path', type=str, default="preprocessed_dataset.p", help='Path to preprocessed dataset') @@ -113,7 +115,7 @@ def get_args(): help='Defines how many iterations to spend on warming up learning rate') parser.add_argument('--clip', type=float, default=0, help='gradient clipping') parser.add_argument('--weight_decay', type=float, default=0, - help='Weight decay') + help='Weight decay for parameter updates') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='report interval') parser.add_argument('--resume_training', type=int, default=0, @@ -126,17 +128,17 @@ def get_args(): help='Path to preprocessed character-level vocabulary') parser.add_argument('--gpu', type=str, default=None, help='Coma-separated ids of the gpu to use. Empty means to use cpu.') - parser.add_argument('--train_unk_token', type=bool, default=False, + parser.add_argument('--train_unk_token', default=False, action='store_true', help='Should train unknown token of embedding') parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], help='Use float16 or float32 precision') - parser.add_argument('--filter_long_context', type=bool, default=True, + parser.add_argument('--filter_long_context', default=True, action='store_false', help='Filter contexts if the answer is after ctx_max_len') parser.add_argument('--save_prediction_path', type=str, default='', help='Path to save predictions') - parser.add_argument('--use_multiprecision_in_optimizer', type=bool, default=False, + parser.add_argument('--use_multiprecision_in_optimizer', default=True, action='store_false', help='When using float16, shall optimizer use multiprecision.') - parser.add_argument('--use_exponential_moving_average', type=bool, default=False, + parser.add_argument('--use_exponential_moving_average', default=True, action='store_false', help='Should averaged copy of parameters been stored and used ' 'during evaluation.') parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, From b3d95e66ed534f7aebe1f110770bbd9f46c34308 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Mon, 22 Oct 2018 10:55:39 -0700 Subject: [PATCH 29/43] Early stop is added --- scripts/question_answering/data_processing.py | 18 +- .../performance_evaluator.py | 17 +- .../question_answering/question_answering.py | 41 ++-- scripts/question_answering/tokenizer.py | 9 - .../train_question_answering.py | 206 +++++++++++------- scripts/question_answering/utils.py | 8 +- 6 files changed, 165 insertions(+), 134 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 873fb24fb9..9356d39424 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,6 +19,7 @@ # pylint: disable= """SQuAD data preprocessing.""" +import logging import pickle from os.path import isfile @@ -106,7 +107,7 @@ def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): # 2.1 Find char index range for the answer (not tokenized) # 2.2 Find Context token indices which char indices contains answer char indices # 2.3. Return first and last token indices - context_char_indices = SQuADTransform._get_char_indices(context, context_tokens) + context_char_indices = SQuADTransform.get_char_indices(context, context_tokens) for answer_start_char_index, answer in zip(answer_start_list, answer_list): answer_token_indices = [] @@ -130,7 +131,7 @@ def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): return answer_spans @staticmethod - def _get_char_indices(text, text_tokens): + def get_char_indices(text, text_tokens): """Match token with character indices :param str text: Text @@ -244,19 +245,6 @@ def get_word_level_vocab(self, embedding_size): word_level_vocab.set_embedding( nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) - count = 0 - count2 = 0 - for i in range(len(word_level_vocab)): - if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): - count += 1 - else: - if count2 < 50: - print(word_level_vocab.embedding.idx_to_token[i]) - count2 += 1 - - print("word_level_vocab {}, word_level_vocab.set_embedding {}".format( - len(word_level_vocab), count)) - if self._options.word_vocab_path: pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index f5c340ddd0..aeeb5aea1c 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -21,6 +21,8 @@ import multiprocessing from mxnet import nd, gluon from mxnet.gluon.data import DataLoader, ArrayDataset + +from scripts.question_answering.data_processing import SQuADTransform from scripts.question_answering.official_squad_eval_script import evaluate @@ -145,17 +147,12 @@ def get_text_result(self, idx, answer_span): question_id = self._mapper.idx_to_question_id[idx] context = self._mapper.question_id_to_context[question_id] - context_tokens = self._tokenizer(context) - - # start index is above the context length - return cannot provide an answer - if start > len(context_tokens) - 1: - return '' - - # end index is above the context length - let's take answer to the end of the context - if end > len(context_tokens) - 1: - end = len(context_tokens) - 1 + context_tokens = self._tokenizer(context, lower_case=True) + indices = SQuADTransform.get_char_indices(context, context_tokens) - text = ' '.join(context_tokens[start:end + 1]) + # get text from cutting string from the initial context + # because tokens are hard to combine together + text = context[indices[start][0]:indices[end][1]] return text @staticmethod diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index b4d75088e4..3037e96347 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -40,7 +40,7 @@ class BiDAFEmbedding(HybridBlock): """ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, contextual_embedding_nlayers=2, highway_nlayers=2, embedding_size=100, - precision='float32', prefix=None, params=None): + dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) self._word_vocab = word_vocab @@ -53,6 +53,7 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), output_dim=8, weight_initializer=initializer.Uniform(0.001)) + self._dropout = nn.Dropout(rate=dropout) self._char_conv_embedding = ConvolutionalEncoder( embed_size=8, num_filters=(100,), @@ -69,7 +70,8 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) self._contextual_embedding = LSTM(hidden_size=embedding_size, num_layers=contextual_embedding_nlayers, - bidirectional=True, input_size=2 * embedding_size) + bidirectional=True, input_size=2 * embedding_size, + dropout=dropout) def init_embeddings(self, lock_gradients): self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) @@ -90,6 +92,7 @@ def begin_state(self, ctx, batch_sizes=None): def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): word_embedded = self._word_embedding(w) char_level_data = self._char_dense_embedding(c) + char_level_data = self._dropout(char_level_data) # Transpose to put seq_len first axis to iterate over it char_level_data = F.transpose(char_level_data, axes=(1, 2, 0, 3)) @@ -195,23 +198,26 @@ class BiDAFOutputLayer(HybridBlock): params : `ParameterDict` or `None` Shared Parameters for this `Block`. """ - def __init__(self, batch_size, span_start_input_dim=100, in_units=None, nlayers=1, biflag=True, + def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, dropout=0.2, precision='float32', prefix=None, params=None): super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) - in_units = 10 * span_start_input_dim if in_units is None else in_units self._batch_size = batch_size self._precision = precision with self.name_scope(): self._dropout = nn.Dropout(rate=dropout) - self._start_index_dense = nn.Dense(units=1, in_units=in_units, - use_bias=False, flatten=False) + self._start_index_g = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._start_index_m = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=2 * span_start_input_dim) - self._end_index_dense = nn.Dense(units=1, in_units=in_units, - use_bias=False, flatten=False) + self._end_index_g = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._end_index_m = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) def begin_state(self, ctx, batch_sizes=None): if batch_sizes is None: @@ -224,19 +230,17 @@ def begin_state(self, ctx, batch_sizes=None): return state_list def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=arguments-differ - # setting batch size as the first dimension - start_index_input = F.transpose(F.concat(x, m, dim=2), axes=(1, 0, 2)) - start_index_input = self._dropout(start_index_input) - - start_index_dense_output = self._start_index_dense(start_index_input) + x = F.transpose(x, axes=(1, 0, 2)) - end_index_input_part, _ = self._end_index_lstm(m, state) - end_index_input = F.transpose(F.concat(x, end_index_input_part, dim=2), - axes=(1, 0, 2)) + start_index_dense_output = self._start_index_g(self._dropout(x)) + \ + self._start_index_m(self._dropout(F.transpose(m, + axes=(1, 0, 2)))) - end_index_input = self._dropout(end_index_input) - end_index_dense_output = self._end_index_dense(end_index_input) + m2, _ = self._end_index_lstm(m, state) + end_index_dense_output = self._end_index_g(self._dropout(x)) + \ + self._end_index_m(self._dropout(F.transpose(m2, + axes=(1, 0, 2)))) start_index_dense_output = F.squeeze(start_index_dense_output) start_index_dense_output_masked = start_index_dense_output + ((1 - mask) * @@ -266,6 +270,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.ctx_embedding_num_layers, options.highway_num_layers, options.embedding_size, + dropout=options.dropout, precision=options.precision, prefix="context_embedding") diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index 49ca53f406..d60eb9ff4c 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -58,12 +58,3 @@ def _process_tokens(temp_tokens): tokens.extend(re.split("([{}])".format("".join(splitters)), token)) return tokens - - @staticmethod - def _isascii(token): - try: - token.encode('ascii') - return True - - except UnicodeEncodeError: - return False diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 4565684872..65d9d7c78f 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -18,6 +18,7 @@ # under the License. import math +import copy import multiprocessing import os from mxnet.gluon.loss import SoftmaxCrossEntropyLoss @@ -69,7 +70,7 @@ def transform_dataset(dataset, vocab_provider, options): A tuple of dataset, QuestionIdMapper and original json data for evaluation """ transformer = SQuADTransform(vocab_provider, options.q_max_len, - options.ctx_max_len, options.word_max_len, args.embedding_size) + options.ctx_max_len, options.word_max_len, options.embedding_size) processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset)]) return processed_dataset @@ -190,12 +191,12 @@ def run_training(net, dataloader, ctx, options): if options.rho: hyperparameters["rho"] = options.rho - trainer = Trainer(net.collect_params(), args.optimizer, hyperparameters, kvstore="device", - update_on_kvstore=False) + trainer = Trainer(net.collect_params(), options.optimizer, hyperparameters, + kvstore="device", update_on_kvstore=False) - if args.resume_training: + if options.resume_training: path = os.path.join(options.save_dir, - 'trainer_epoch{:d}.params'.format(args.resume_training - 1)) + 'trainer_epoch{:d}.params'.format(options.resume_training - 1)) trainer.load_states(path) loss_function = SoftmaxCrossEntropyLoss() @@ -204,9 +205,13 @@ def run_training(net, dataloader, ctx, options): train_start = time() avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) iteration = 1 + max_dev_exact = -1 + max_dev_f1 = -1 + print("Starting training...") - for e in range(0 if not args.resume_training else args.resume_training, args.epochs): + for e in range(0 if not options.resume_training else options.resume_training, + options.epochs): avg_loss *= 0 # Zero average loss of each epoch ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) @@ -259,13 +264,13 @@ def run_training(net, dataloader, ctx, options): for loss in losses: loss.backward() - if iteration == 1 and args.use_exponential_moving_average: + if iteration == 1 and options.use_exponential_moving_average: ema = PolyakAveraging(net.collect_params(), - args.exponential_moving_average_weight_decay) + options.exponential_moving_average_weight_decay) - if args.resume_training: + if options.resume_training: path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format( - args.resume_training - 1)) + options.resume_training - 1)) ema.get_params().load(path) # in special mode we collect gradients and apply processing only after @@ -309,6 +314,20 @@ def run_training(net, dataloader, ctx, options): if ema is not None: ema.update() + if options.early_stop and \ + e == options.epochs - 1 and \ + iteration % options.log_interval == 0: + result = run_evaluate_mode(options, net, ema) + + if result["f1"] > max_dev_f1: + max_dev_f1 = result["f1"] + max_dev_exact = result["exact_match"] + print("New best evaluation results on dev dataset: {}".format(result)) + else: + print("Results starts decreasing. Stopping training...") + # Best parameters are saved as "-1" epoch + break + for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) @@ -479,96 +498,125 @@ def load_transformed_dataset(path): return processed_dataset -if __name__ == "__main__": - args = get_args() - args.batch_size = int(args.batch_size / len(get_context(args))) - print(args) - logging_config(args.save_dir) +def run_preprocess_mode(options): + # we use both datasets to create proper vocab + dataset_train = SQuAD(segment='train') + dataset_dev = SQuAD(segment='dev') - if args.preprocess: - if not args.preprocessed_dataset_path: - logging.error("Preprocessed_data_path attribute is not provided") - exit(1) + vocab_provider = VocabProvider([dataset_train, dataset_dev], options) + transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) - print("Running in preprocessing mode") + if options.preprocessed_val_dataset_path: + transformed_dataset = transform_dataset(dataset_dev, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_val_dataset_path) - # we use both datasets to create proper vocab - dataset_train = SQuAD(segment='train') - dataset_dev = SQuAD(segment='dev') - vocab_provider = VocabProvider([dataset_train, dataset_dev], args) - transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=args) - save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) +def run_training_mode(options): + dataset = SQuAD(segment='train') + dataset_val = SQuAD(segment='dev') + vocab_provider = VocabProvider([dataset, dataset_val], options) - if args.preprocessed_val_dataset_path: - transformed_dataset = transform_dataset(dataset_dev, vocab_provider, options=args) - save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) + if options.preprocessed_dataset_path and isfile(options.preprocessed_dataset_path): + transformed_dataset = load_transformed_dataset(options.preprocessed_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) - if args.train: - print("Running in training mode") + train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, options) + word_vocab, char_vocab = get_vocabs(vocab_provider, options=options) + ctx = get_context(options) - dataset = SQuAD(segment='train') - dataset_val = SQuAD(segment='dev') - vocab_provider = VocabProvider([dataset, dataset_val], args) + net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") + net.cast(options.precision) + net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) + net.hybridize(static_alloc=True) - if args.preprocessed_dataset_path and isfile(args.preprocessed_dataset_path): - transformed_dataset = load_transformed_dataset(args.preprocessed_dataset_path) - else: - transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) - save_transformed_dataset(transformed_dataset, args.preprocessed_dataset_path) + if options.grad_req_add_mode: + net.collect_params().setattr('grad_req', 'add') - train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, args) - word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) - ctx = get_context(args) + if options.resume_training: + print("Resuming training from {} epoch".format(options.resume_training)) + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(int(options.resume_training) - 1)) + net.load_parameters(params_path, ctx) - net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") - net.cast(args.precision) - net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) - net.hybridize(static_alloc=True) + run_training(net, train_dataloader, ctx, options=options) - if args.grad_req_add_mode: - net.collect_params().setattr('grad_req', 'add') - if args.resume_training: - print("Resuming training from {} epoch".format(args.resume_training)) - params_path = os.path.join(args.save_dir, - 'epoch{:d}.params'.format(int(args.resume_training) - 1)) - net.load_parameters(params_path, ctx) +def run_evaluate_mode(options, existing_net=None, existing_ema=None): + train_dataset = SQuAD(segment='train') + dataset = SQuAD(segment='dev') - run_training(net, train_dataloader, ctx, options=args) + if existing_net: + # currently evaluate can work only with batch_size of 10 + options = copy.deepcopy(options) + options.batch_size = 10 - if args.evaluate: - print("Running in evaluation mode") + vocab_provider = VocabProvider([train_dataset, dataset], options) + mapper = QuestionIdMapper(dataset) - train_dataset = SQuAD(segment='train') - dataset = SQuAD(segment='dev') + if options.preprocessed_val_dataset_path and isfile(options.preprocessed_val_dataset_path): + transformed_dataset = load_transformed_dataset(options.preprocessed_val_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_val_dataset_path) - vocab_provider = VocabProvider([train_dataset, dataset], args) - mapper = QuestionIdMapper(dataset) + word_vocab, char_vocab = get_vocabs(vocab_provider, options=options) + ctx = get_context(options) - if args.preprocessed_val_dataset_path and isfile(args.preprocessed_val_dataset_path): - transformed_dataset = load_transformed_dataset(args.preprocessed_val_dataset_path) - else: - transformed_dataset = transform_dataset(dataset, vocab_provider, options=args) - save_transformed_dataset(transformed_dataset, args.preprocessed_val_dataset_path) + evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, + dataset._read_data(), mapper) - word_vocab, char_vocab = get_vocabs(vocab_provider, options=args) - ctx = get_context(args) + net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") - evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, - dataset._read_data(), mapper) - net = BiDAFModel(word_vocab, char_vocab, args, prefix="bidaf") + if options.use_exponential_moving_average: + if existing_ema is None: + params_path = os.path.join(options.save_dir, + 'ema_epoch{:d}.params'.format(int(options.epochs) - 1)) + else: + save_ema_parameters(existing_ema, -1, options) + params_path = os.path.join(options.save_dir, + 'ema_epoch{:d}.params'.format(-1)) - if args.use_exponential_moving_average: - params_path = os.path.join(args.save_dir, - 'ema_epoch{:d}.params'.format(int(args.epochs) - 1)) - net.collect_params().load(params_path, ctx=ctx) + net.collect_params().load(params_path, ctx=ctx) + else: + if existing_net is None: + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(int(options.epochs) - 1)) else: - params_path = os.path.join(args.save_dir, - 'epoch{:d}.params'.format(int(args.epochs) - 1)) - net.load_parameters(params_path, ctx=ctx) + save_model_parameters(existing_net, -1, options) + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(-1)) + + net.load_parameters(params_path, ctx=ctx) + + net.hybridize(static_alloc=True) - net.hybridize(static_alloc=True) + result = evaluator.evaluate_performance(net, ctx, options) + return result + + +if __name__ == "__main__": + args = get_args() + args.batch_size = int(args.batch_size / len(get_context(args))) + print(args) + logging_config(args.save_dir) - result = evaluator.evaluate_performance(net, ctx, args) + if args.preprocess: + if not args.preprocessed_dataset_path: + logging.error("Preprocessed_data_path attribute is not provided") + exit(1) + + print("Running in preprocessing mode") + run_preprocess_mode(args) + + if args.train: + print("Running in training mode") + run_training_mode(args) + + if args.evaluate: + print("Running in evaluation mode") + result = run_evaluate_mode(args) print("Evaluation results on dev dataset: {}".format(result)) + diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index e52e237c1a..11caff9dd0 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -109,15 +109,17 @@ def get_args(): parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') - parser.add_argument('--rho', type=float, default=0.9, + parser.add_argument('--rho', type=float, default=0.95, help='Adadelta decay rate for both squared gradients and delta.') parser.add_argument('--lr_warmup_steps', type=int, default=0, help='Defines how many iterations to spend on warming up learning rate') parser.add_argument('--clip', type=float, default=0, help='gradient clipping') - parser.add_argument('--weight_decay', type=float, default=0, + parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay for parameter updates') - parser.add_argument('--log_interval', type=int, default=100, metavar='N', + parser.add_argument('--log_interval', type=int, default=250, metavar='N', help='report interval') + parser.add_argument('--early_stop', default=False, action='store_true', + help='Apply early stopping') parser.add_argument('--resume_training', type=int, default=0, help='Resume training from this epoch number') parser.add_argument('--save_dir', type=str, default='out_dir', From 89988362af0d8d084a96b35d3001ca4d2bf03abd Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Mon, 22 Oct 2018 12:10:20 -0700 Subject: [PATCH 30/43] Can log results without early stopping --- .../train_question_answering.py | 48 +++++++++---------- scripts/question_answering/utils.py | 7 +-- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 65d9d7c78f..8c17cf7a63 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -314,19 +314,24 @@ def run_training(net, dataloader, ctx, options): if ema is not None: ema.update() - if options.early_stop and \ - e == options.epochs - 1 and \ - iteration % options.log_interval == 0: - result = run_evaluate_mode(options, net, ema) - - if result["f1"] > max_dev_f1: - max_dev_f1 = result["f1"] - max_dev_exact = result["exact_match"] - print("New best evaluation results on dev dataset: {}".format(result)) - else: - print("Results starts decreasing. Stopping training...") - # Best parameters are saved as "-1" epoch - break + if e == options.epochs - 1 and \ + iteration > 0 and iteration % options.log_interval == 0: + evaluate_options = copy.deepcopy(options) + evaluate_options.batch_size = 10 + evaluate_options.epochs = iteration + result = run_evaluate_mode(evaluate_options, net, ema) + + print("Iteration {} evaluation results on dev dataset: {}".format(iteration, + result)) + if options.early_stop: + if result["f1"] > max_dev_f1: + max_dev_f1 = result["f1"] + max_dev_exact = result["exact_match"] + else: + print("Results starts decreasing - stopping training. " + "Best results are stored at {} params file"\ + .format(iteration - options.log_interval)) + break for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) @@ -548,11 +553,6 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): train_dataset = SQuAD(segment='train') dataset = SQuAD(segment='dev') - if existing_net: - # currently evaluate can work only with batch_size of 10 - options = copy.deepcopy(options) - options.batch_size = 10 - vocab_provider = VocabProvider([train_dataset, dataset], options) mapper = QuestionIdMapper(dataset) @@ -575,9 +575,9 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): params_path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format(int(options.epochs) - 1)) else: - save_ema_parameters(existing_ema, -1, options) + save_ema_parameters(existing_ema, options.epochs, options) params_path = os.path.join(options.save_dir, - 'ema_epoch{:d}.params'.format(-1)) + 'ema_epoch{:d}.params'.format(options.epochs)) net.collect_params().load(params_path, ctx=ctx) else: @@ -585,16 +585,14 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): params_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(int(options.epochs) - 1)) else: - save_model_parameters(existing_net, -1, options) + save_model_parameters(existing_net, options.epochs, options) params_path = os.path.join(options.save_dir, - 'epoch{:d}.params'.format(-1)) + 'epoch{:d}.params'.format(options.epochs)) net.load_parameters(params_path, ctx=ctx) net.hybridize(static_alloc=True) - - result = evaluator.evaluate_performance(net, ctx, options) - return result + return evaluator.evaluate_performance(net, ctx, options) if __name__ == "__main__": diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 11caff9dd0..6c197d5440 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -116,10 +116,11 @@ def get_args(): parser.add_argument('--clip', type=float, default=0, help='gradient clipping') parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay for parameter updates') - parser.add_argument('--log_interval', type=int, default=250, metavar='N', - help='report interval') + parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='Report interval applied to last epoch only') parser.add_argument('--early_stop', default=False, action='store_true', - help='Apply early stopping') + help='Apply early stopping for the last epoch. ' + 'Should be used with log_interval') parser.add_argument('--resume_training', type=int, default=0, help='Resume training from this epoch number') parser.add_argument('--save_dir', type=str, default='out_dir', From 2528b8ea8d9259bb7df60fbd8c4dbbe9c3986856 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 24 Oct 2018 11:16:43 -0700 Subject: [PATCH 31/43] NLTK tokenizer is used to fix [citation needed] --- scripts/question_answering/bidaf.py | 9 +-- scripts/question_answering/data_processing.py | 15 ++++ .../performance_evaluator.py | 1 + .../question_answering/question_answering.py | 23 +++++- .../question_answering/similarity_function.py | 27 ++++--- scripts/question_answering/tokenizer.py | 14 ++-- .../train_question_answering.py | 78 +++++++++++++++---- scripts/question_answering/utils.py | 14 ++-- scripts/tests/test_question_answering.py | 25 ++++++ 9 files changed, 154 insertions(+), 52 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index 8ee7884599..a9f573fa74 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -32,7 +32,6 @@ class BidirectionalAttentionFlow(gluon.HybridBlock): """ def __init__(self, - attention_similarity_function, batch_size, passage_length, question_length, @@ -46,9 +45,6 @@ def __init__(self, self._question_length = question_length self._encoding_dim = encoding_dim self._precision = precision - self._matrix_attention = AttentionFlow(attention_similarity_function, - batch_size, passage_length, question_length, - encoding_dim) def _get_big_negative_value(self): if self._precision == 'float16': @@ -62,13 +58,12 @@ def _get_small_positive_value(self): else: return np.finfo(np.float32).eps - def hybrid_forward(self, F, encoded_passage, encoded_question, question_mask, passage_mask): + def hybrid_forward(self, F, passage_question_similarity, + encoded_passage, encoded_question, question_mask, passage_mask): # pylint: disable=arguments-differ """ """ - # Shape: (batch_size, passage_length, question_length) - passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) passage_question_similarity_shape = (self._batch_size, self._passage_length, self._question_length) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 9356d39424..d8a9aabb70 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -245,6 +245,21 @@ def get_word_level_vocab(self, embedding_size): word_level_vocab.set_embedding( nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) + # count = 0 + # words_no_embedding = [] + # for i in range(len(word_level_vocab)): + # if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): + # count += 1 + # else: + # words_no_embedding.append(word_level_vocab.embedding.idx_to_token[i]) + # + # with open("no_embedding_words.txt", "w") as f: + # for word in words_no_embedding: + # f.write(word + "\n") + # + # print("word_level_vocab {}, word_level_vocab.set_embedding {}".format( + # len(word_level_vocab), count)) + if self._options.word_vocab_path: pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index aeeb5aea1c..910122c4dd 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -64,6 +64,7 @@ def evaluate_performance(self, net, ctx, options): for r in self._evaluation_dataset]) eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, last_batch='keep', + pin_memory=True, num_workers=(multiprocessing.cpu_count() - len(ctx) - 2)) for i, data in enumerate(eval_dataloader): diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 3037e96347..094c565963 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -18,8 +18,10 @@ # under the License. """BiDAF model blocks""" +from scripts.question_answering.attention_flow import AttentionFlow from scripts.question_answering.bidaf import BidirectionalAttentionFlow -from scripts.question_answering.similarity_function import DotProductSimilarity +from scripts.question_answering.similarity_function import DotProductSimilarity, CosineSimilarity, \ + LinearSimilarity from scripts.question_answering.utils import get_very_negative_number __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] @@ -274,9 +276,18 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): precision=options.precision, prefix="context_embedding") + self.similarity_function = LinearSimilarity(array_1_dim=6 * options.embedding_size, + array_2_dim=1, + combination="x,y,x*y") + + self.matrix_attention = AttentionFlow(self.similarity_function, + options.batch_size, + options.ctx_max_len, + options.q_max_len, + 2 * options.embedding_size) + # we multiple embedding_size by 2 because we use bidirectional embedding - self.attention_layer = BidirectionalAttentionFlow(DotProductSimilarity(), - options.batch_size, + self.attention_layer = BidirectionalAttentionFlow(options.batch_size, options.ctx_max_len, options.q_max_len, 2 * options.embedding_size, @@ -314,7 +325,11 @@ def hybrid_forward(self, F, qw, cw, qc, cc, q_mask = qw != 0 ctx_mask = cw != 0 - attention_layer_output = self.attention_layer(ctx_embedding_output, + passage_question_similarity = self.matrix_attention(ctx_embedding_output, + q_embedding_output).sum(axis=-1) + + attention_layer_output = self.attention_layer(passage_question_similarity, + ctx_embedding_output, q_embedding_output, q_mask, ctx_mask) diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index 0b52e24bf7..5f76cb9906 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -18,7 +18,7 @@ # under the License. import mxnet as mx -from mxnet import gluon +from mxnet import gluon, initializer from mxnet.gluon import nn, Parameter from .utils import combine_tensors @@ -78,8 +78,8 @@ class CosineSimilarity(SimilarityFunction): """ def hybrid_forward(self, F, array_1, array_2): - normalized_array_1 = array_1 / F.norm(array_1, axis=-1, keepdims=True) - normalized_array_2 = array_2 / F.norm(array_2, axis=-1, keepdims=True) + normalized_array_1 = F.broadcast_div(array_1, F.norm(array_1, axis=-1, keepdims=True)) + normalized_array_2 = F.broadcast_div(array_2, F.norm(array_2, axis=-1, keepdims=True)) return (normalized_array_1 * normalized_array_2).sum(axis=-1) @@ -164,21 +164,26 @@ def __init__(self, activation='linear', **kwargs): super(LinearSimilarity, self).__init__(**kwargs) - self._combination = combination - self._weight_matrix = Parameter(name="weight_matrix", - shape=(array_1_dim, array_2_dim), init=mx.init.Uniform()) - self._bias = Parameter(name="bias", shape=(array_1_dim,), init=mx.init.Zero()) + self.combination = combination if activation == 'linear': self._activation = None else: self._activation = nn.Activation(activation) - def hybrid_forward(self, F, array_1, array_2): - combined_tensors = combine_tensors(self._combination, [array_1, array_1]) - dot_product = F.broadcast_mull(combined_tensors, self._weight_matrix) + with self.name_scope(): + self.weight_matrix = self.params.get("weight_matrix", + shape=(array_1_dim, array_2_dim), + init=initializer.Uniform()) + self.bias = self.params.get("bias", + shape=(array_1_dim,), + init=initializer.Zero()) + + def hybrid_forward(self, F, array_1, array_2, weight_matrix, bias): + combined_tensors = combine_tensors(F, self.combination, [array_1, array_2]) + dot_product = F.batch_dot(combined_tensors, weight_matrix) if not self._activation: return dot_product - return self._activation(dot_product + self._bias) + return self._activation(dot_product + bias) diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index d60eb9ff4c..31a1953499 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -18,12 +18,12 @@ # under the License. import re -from gluonnlp.data import SpacyTokenizer +import nltk class BiDAFTokenizer: - def __init__(self, base_tokenizer=SpacyTokenizer()): - self._base_tokenizer = base_tokenizer + def __init__(self): + pass def __call__(self, sample, lower_case=False): """ @@ -38,10 +38,12 @@ def __call__(self, sample, lower_case=False): ret : list of strs List of tokens """ - sample = sample.replace('\'\'', '\" ').replace(r'``', '\" ')\ - .replace(u'\u000A', ' ').replace(u'\u00A0', ' ') - tokens = self._base_tokenizer(sample) + sample = sample.replace('\n', ' ').replace(u'\u000A', '').replace(u'\u00A0', '') + + tokens = [token.replace("''", '"').replace("``", '"') for token in + nltk.word_tokenize(sample)] tokens = BiDAFTokenizer._process_tokens(tokens) + tokens = [token for token in tokens if len(token) > 0] if lower_case: tokens = [token.lower() for token in tokens] diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 8c17cf7a63..0d9a9e97d3 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -47,12 +47,8 @@ from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.utils import logging_config, get_args -np.random.seed(100) -random.seed(100) -mx.random.seed(10000) - -def transform_dataset(dataset, vocab_provider, options): +def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): """Get transformed dataset Parameters @@ -63,15 +59,50 @@ def transform_dataset(dataset, vocab_provider, options): Vocabulary provider options : `Namespace` Data transformation arguments + enable_filtering : `Bool` + Remove data that doesn't match BiDAF model requirements Returns ------- data : Tuple A tuple of dataset, QuestionIdMapper and original json data for evaluation """ + tokenizer = vocab_provider.get_tokenizer() transformer = SQuADTransform(vocab_provider, options.q_max_len, options.ctx_max_len, options.word_max_len, options.embedding_size) - processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset)]) + + transformed_records = [] + long_context = 0 + long_question = 0 + + for i, record in enumerate(dataset): + if enable_filtering: + tokenized_question = tokenizer(record[2], lower_case=True) + # we don't need to dispose of context as long as the answer is still + # present in the context after it is trimmed + # tokenized_context = tokenizer(record[3], lower_case=True) + # + # if len(tokenized_context) > options.ctx_max_len: + # long_context += 1 + # continue + + # but we don't know if the question is still meaningful + if len(tokenized_question) > options.q_max_len: + long_question += 1 + continue + + transformed_record = transformer(*record) + + # if answer end index is after ctx_max_len token or + # it is after q_max_len token we do not use this record + if enable_filtering and transformed_record[6][0][1] >= options.ctx_max_len: + continue + + transformed_records.append(transformed_record) + + processed_dataset = SimpleDataset(transformed_records) + print("{}/{} records. Too long context {}, too long query {}".format( + len(processed_dataset), i + 1, long_context, long_question)) return processed_dataset @@ -116,7 +147,8 @@ def get_record_per_answer_span(processed_dataset, options): dataloader = DataLoader(loadable_data, batch_size=options.batch_size * len(get_context(options)), shuffle=True, - last_batch='discard', + last_batch='rollover', + pin_memory=True, num_workers=(multiprocessing.cpu_count() - len(get_context(options)) - 2)) @@ -207,6 +239,8 @@ def run_training(net, dataloader, ctx, options): iteration = 1 max_dev_exact = -1 max_dev_f1 = -1 + max_iteration = -1 + early_stop_tries = 0 print("Starting training...") @@ -307,14 +341,15 @@ def run_training(net, dataloader, ctx, options): for dest in destination: source.copyto(dest) - trainer.update(scailing_coeff) + trainer.update(scailing_coeff, ignore_stale_grad=True) else: - trainer.step(scailing_coeff) + trainer.step(scailing_coeff, ignore_stale_grad=True) if ema is not None: ema.update() if e == options.epochs - 1 and \ + options.log_interval > 0 and \ iteration > 0 and iteration % options.log_interval == 0: evaluate_options = copy.deepcopy(options) evaluate_options.batch_size = 10 @@ -327,11 +362,18 @@ def run_training(net, dataloader, ctx, options): if result["f1"] > max_dev_f1: max_dev_f1 = result["f1"] max_dev_exact = result["exact_match"] + max_iteration = iteration + early_stop_tries = 0 else: - print("Results starts decreasing - stopping training. " - "Best results are stored at {} params file"\ - .format(iteration - options.log_interval)) - break + if early_stop_tries < options.early_stop: + early_stop_tries += 1 + print("Results decreased for {} times".format(early_stop_tries)) + else: + print("Results decreased for {} times. Stop training. " + "Best results are stored at {} params file. F1={}, EM={}"\ + .format(options.early_stop + 1, max_iteration, + max_dev_f1, max_dev_exact)) + break for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) @@ -509,7 +551,8 @@ def run_preprocess_mode(options): dataset_dev = SQuAD(segment='dev') vocab_provider = VocabProvider([dataset_train, dataset_dev], options) - transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=options) + transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=options, + enable_filtering=True) save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) if options.preprocessed_val_dataset_path: @@ -525,7 +568,8 @@ def run_training_mode(options): if options.preprocessed_dataset_path and isfile(options.preprocessed_dataset_path): transformed_dataset = load_transformed_dataset(options.preprocessed_dataset_path) else: - transformed_dataset = transform_dataset(dataset, vocab_provider, options=options) + transformed_dataset = transform_dataset(dataset, vocab_provider, options=options, + enable_filtering=True) save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, options) @@ -534,8 +578,8 @@ def run_training_mode(options): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.cast(options.precision) - net.initialize(init.Xavier(magnitude=2.24), ctx=ctx) - net.hybridize(static_alloc=True) + net.initialize(init.Xavier(), ctx=ctx) + # net.hybridize(static_alloc=True) if options.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 6c197d5440..e212798979 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -109,18 +109,18 @@ def get_args(): parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') - parser.add_argument('--rho', type=float, default=0.95, + parser.add_argument('--rho', type=float, default=0.9, help='Adadelta decay rate for both squared gradients and delta.') parser.add_argument('--lr_warmup_steps', type=int, default=0, help='Defines how many iterations to spend on warming up learning rate') parser.add_argument('--clip', type=float, default=0, help='gradient clipping') - parser.add_argument('--weight_decay', type=float, default=0.0001, + parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay for parameter updates') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='Report interval applied to last epoch only') - parser.add_argument('--early_stop', default=False, action='store_true', - help='Apply early stopping for the last epoch. ' - 'Should be used with log_interval') + parser.add_argument('--early_stop', type=int, default=4, + help='Apply early stopping for the last epoch. Stop after # of consequent ' + '# of times F1 is lower than max. Should be used with log_interval') parser.add_argument('--resume_training', type=int, default=0, help='Resume training from this epoch number') parser.add_argument('--save_dir', type=str, default='out_dir', @@ -209,7 +209,7 @@ def _get_combination(combination, tensors): raise NotImplementedError -def combine_tensors(combination, tensors): +def combine_tensors(F, combination, tensors): """ Combines a list of tensors using element-wise operations and concatenation, specified by a ``combination`` string. The string refers to (1-indexed) positions in the input tensor list, @@ -235,7 +235,7 @@ def combine_tensors(combination, tensors): """ combination = combination.replace('x', '1').replace('y', '2') to_concatenate = [_get_combination(piece, tensors) for piece in combination.split(',')] - return nd.concat(to_concatenate, dim=-1) + return F.concat(*to_concatenate, dim=-1) def masked_softmax(F, vector, mask, epsilon): diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index e9dc1c5c32..f2df79131d 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -393,6 +393,22 @@ def test_get_answer_spans_after_comma(): assert result == [(23, 23)] +def test_get_answer_spans_after_quotes(): + tokenizer = BiDAFTokenizer() + + context = "In the film Knute Rockne, All American, Knute Rockne (played by Pat O'Brien) delivers the famous ""Win one for the Gipper"" speech, at which point the background music swells with the ""Notre Dame Victory March"". George Gipp was played by Ronald Reagan, whose nickname ""The Gipper"" was derived from this role. This scene was parodied in the movie Airplane! with the same background music, only this time honoring George Zipp, one of Ted Striker's former comrades. The song also was prominent in the movie Rudy, with Sean Astin as Daniel ""Rudy"" Ruettiger, who harbored dreams of playing football at the University of Notre Dame despite significant obstacles." + context_tokens = tokenizer(context, lower_case=True) + indices = SQuADTransform.get_char_indices(context, context_tokens) + + answer_start_char_index = 267 + answer = "The Gipper" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_char_index]) + + assert result == [(54, 55)] + + def test_get_char_indices(): context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." tokenizer = BiDAFTokenizer() @@ -402,6 +418,15 @@ def test_get_char_indices(): assert len(result) == len(context_tokens) +def test_tokenizer_split_new_lines(): + context = "that are of equal energy\u2014i.e., degenerate\u2014is a configuration termed a spin triplet state. Hence, the ground state of the O\n2 molecule is referred to as triplet oxygen" + + tokenizer = BiDAFTokenizer() + context_tokens = tokenizer(context) + + assert len(context_tokens) == 35 + + def test_polyak_averaging(): net = nn.HybridSequential() net.add(nn.Dense(5), nn.Dense(3), nn.Dense(2)) From 2a28253ca00abcd5b2e765ed39f205891db32c19 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 25 Oct 2018 10:10:47 -0700 Subject: [PATCH 32/43] Bidaf similarity is used --- .../question_answering/question_answering.py | 7 ++++++- .../question_answering/similarity_function.py | 18 ++++++++++++------ .../train_question_answering.py | 6 +++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 094c565963..0b6ee11c3c 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -326,7 +326,12 @@ def hybrid_forward(self, F, qw, cw, qc, cc, ctx_mask = cw != 0 passage_question_similarity = self.matrix_attention(ctx_embedding_output, - q_embedding_output).sum(axis=-1) + q_embedding_output) + + passage_question_similarity = passage_question_similarity.reshape( + shape=(self._options.batch_size, + self._options.ctx_max_len, + self._options.q_max_len)) attention_layer_output = self.attention_layer(passage_question_similarity, ctx_embedding_output, diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index 5f76cb9906..181edf0694 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -160,11 +160,15 @@ class LinearSimilarity(SimilarityFunction): def __init__(self, array_1_dim, array_2_dim, + use_bias=False, combination='x,y', activation='linear', **kwargs): super(LinearSimilarity, self).__init__(**kwargs) self.combination = combination + self.use_bias = use_bias + self.array_1_dim = array_1_dim + self.array_2_dim = array_2_dim if activation == 'linear': self._activation = None @@ -173,15 +177,17 @@ def __init__(self, with self.name_scope(): self.weight_matrix = self.params.get("weight_matrix", - shape=(array_1_dim, array_2_dim), + shape=(array_2_dim, array_1_dim), init=initializer.Uniform()) - self.bias = self.params.get("bias", - shape=(array_1_dim,), - init=initializer.Zero()) + if use_bias: + self.bias = self.params.get("bias", + shape=(array_2_dim,), + init=initializer.Zero()) - def hybrid_forward(self, F, array_1, array_2, weight_matrix, bias): + def hybrid_forward(self, F, array_1, array_2, weight_matrix, bias=None): combined_tensors = combine_tensors(F, self.combination, [array_1, array_2]) - dot_product = F.batch_dot(combined_tensors, weight_matrix) + dot_product = F.FullyConnected(combined_tensors, weight_matrix, bias=bias, flatten=False, + no_bias=not self.use_bias, num_hidden=self.array_2_dim) if not self._activation: return dot_product diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 0d9a9e97d3..555380918d 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -341,9 +341,9 @@ def run_training(net, dataloader, ctx, options): for dest in destination: source.copyto(dest) - trainer.update(scailing_coeff, ignore_stale_grad=True) + trainer.update(scailing_coeff) else: - trainer.step(scailing_coeff, ignore_stale_grad=True) + trainer.step(scailing_coeff) if ema is not None: ema.update() @@ -579,7 +579,7 @@ def run_training_mode(options): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.cast(options.precision) net.initialize(init.Xavier(), ctx=ctx) - # net.hybridize(static_alloc=True) + net.hybridize() if options.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') From f62e1c5ced3a184b9c4c7f379765a02359468350 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 25 Oct 2018 15:38:17 -0700 Subject: [PATCH 33/43] Add comments to code --- scripts/question_answering/data_processing.py | 150 ++++++++++++++---- .../exponential_moving_average.py | 26 ++- .../performance_evaluator.py | 6 + .../question_answering/question_answering.py | 57 ++++++- .../question_answering/question_id_mapper.py | 24 +++ scripts/question_answering/tokenizer.py | 23 ++- .../train_question_answering.py | 110 +++++++++---- scripts/question_answering/utils.py | 4 +- 8 files changed, 320 insertions(+), 80 deletions(-) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index d8a9aabb70..0b94cc8872 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -58,6 +58,29 @@ def __call__(self, record_index, question_id, question, context, answer_list, """ Method converts text into numeric arrays based on Vocabulary. Answers are not processed, as they are not needed in input + + Parameters + ---------- + + record_index: int + Index of the record + question_id: str + Question Id + question: str + Question + context: str + Context + answer_list: list[str] + List of answers + answer_start_list: list[int] + List of start indices of answers + + Returns + ------- + record: Tuple + A tuple containing record_index [int], question_id [str], question_words_nd [NDArray], + context_words_nd [NDArray], question_chars_nd [NDArray[, context_chars_nd [NDArray], + answer_spans [list[Tuple]] """ question_tokens = self._tokenizer(question, lower_case=True) context_tokens = self._tokenizer(context, lower_case=True) @@ -91,13 +114,21 @@ def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): """Find all answer spans from the context, returning start_index and end_index. Each index is a index of a token - :param list[str] context_tokens: Tokenized paragraph - :param list[str] answer_list: List of all answers + Parameters + ---------- + context: str + Context + context_tokens: list[str] + Tokenized context + answer_list: list[str] + List of answers + answer_start_list: list[int] + List of answers start indices Returns ------- - List[Tuple] - list of Tuple(answer_start_index answer_end_index) per question + answer_spans: List[Tuple] + List of Tuple(answer_start_index answer_end_index) per question """ answer_spans = [] # SQuAD answers doesn't always match to used tokens in the context. Sometimes there is only @@ -134,9 +165,17 @@ def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): def get_char_indices(text, text_tokens): """Match token with character indices - :param str text: Text - :param List[str] text_tokens: Tokens of the text - :return: List of char_indexes where the order equals to token index + Parameters + ---------- + text: str + Text + text_tokens: list[str] + Tokens of the text + + Returns + ------- + char_indices_per_token: List[Tuple] + List of (start_index, end_index) of characters where the position equals to token index """ char_indices_per_token = [] current_index = 0 @@ -152,9 +191,17 @@ def get_char_indices(text, text_tokens): def _pad_to_max_char_length(self, item, max_item_length): """Pads all tokens to maximum size - :param NDArray item: matrix of indices - :param int max_item_length: maximum length of a token - :return: + Parameters + ---------- + item: NDArray + Matrix of indices + max_item_length: int + Maximum length of a token + + Returns + ------- + NDArray + Padded NDArray """ # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims data_expanded = item.reshape(1, 1, item.shape[0], item.shape[1]).astype(np.float32) @@ -174,9 +221,17 @@ def _pad_to_max_char_length(self, item, max_item_length): def _pad_to_max_word_length(item, max_length): """Pads sentences to maximum length - :param NDArray item: vector of words - :param int max_length: Maximum length of question/context - :return: + Parameters + ---------- + item: NDArray + Vector of words + max_length: int + Maximum length of question/context + + Returns + ------- + NDArray + Padded vector of tokens """ data_nd = nd.array(item, dtype=np.float32) # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims @@ -200,7 +255,14 @@ def __init__(self, datasets, options, tokenizer=BiDAFTokenizer()): self._tokenizer = tokenizer def get_tokenizer(self): - """Provides tokenizer used to create vocab""" + """Provides tokenizer used to create vocab + + Returns + ------- + tokenizer: Tokenizer + Tokenizer + + """ return self._tokenizer def get_char_level_vocab(self): @@ -245,20 +307,13 @@ def get_word_level_vocab(self, embedding_size): word_level_vocab.set_embedding( nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) - # count = 0 - # words_no_embedding = [] - # for i in range(len(word_level_vocab)): - # if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): - # count += 1 - # else: - # words_no_embedding.append(word_level_vocab.embedding.idx_to_token[i]) - # - # with open("no_embedding_words.txt", "w") as f: - # for word in words_no_embedding: - # f.write(word + "\n") - # - # print("word_level_vocab {}, word_level_vocab.set_embedding {}".format( - # len(word_level_vocab), count)) + count = 0 + + for i in range(len(word_level_vocab)): + if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): + count += 1 + + print("{}/{} words have embeddings".format(count, len(word_level_vocab))) if self._options.word_vocab_path: pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) @@ -267,11 +322,37 @@ def get_word_level_vocab(self, embedding_size): @staticmethod def _create_squad_vocab(all_tokens): + """Provides vocabulary based on list of tokens + + Parameters + ---------- + + all_tokens: List[str] + List of all tokens + + Returns + ------- + Vocab + Vocabulary + """ counter = data.count_tokens(all_tokens) vocab = Vocab(counter) return vocab def _get_all_word_tokens(self, dataset): + """Provides all words from context and question of the dataset + + Parameters + ---------- + + dataset: SimpleDataset + Dataset of SQuAD + + Returns + ------- + all_tokens: list[str] + List of all words + """ all_tokens = [] for data_item in dataset: @@ -281,6 +362,19 @@ def _get_all_word_tokens(self, dataset): return all_tokens def _get_all_char_tokens(self, dataset): + """Provides all characters from context and question of the dataset + + Parameters + ---------- + + dataset: SimpleDataset + Dataset of SQuAD + + Returns + ------- + all_tokens: list[str] + List of all characters + """ all_tokens = [] for data_item in dataset: diff --git a/scripts/question_answering/exponential_moving_average.py b/scripts/question_answering/exponential_moving_average.py index e03612f2a7..651ddc5d73 100644 --- a/scripts/question_answering/exponential_moving_average.py +++ b/scripts/question_answering/exponential_moving_average.py @@ -24,6 +24,8 @@ class PolyakAveraging: + """Class to do Polyak averaging based on this paper + http://www.meyn.ece.ufl.edu/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf""" def __init__(self, params, decay): self._params = params self._decay = decay @@ -47,18 +49,26 @@ def update(self): self._decay * polyak_param.data(mx.cpu())) def get_params(self): - """ - :return: returns the averaged parameters - :rtype: gluon.ParameterDict + """ Provides averaged parameters + + Returns + ------- + gluon.ParameterDict + Averaged parameters """ return self._polyak_params_dict def _param_data_to_cpu(self, param): - """ - Returns a copy (on CPU context) of the data held in some context of given parameter. + """Returns a copy (on CPU context) of the data held in some context of given parameter. + + Parameters + ---------- + param: gluon.Parameter + Parameter's whose data needs to be copied. - :param gluon.Parameter param: parameter's whose data needs to be copied. - :return: copy of data on CPU context. - :rtype: nd.NDArray + Returns + ------- + NDArray + Copy of data on CPU context. """ return param.list_data()[0].copyto(mx.cpu()) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 910122c4dd..38d2266245 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -27,6 +27,7 @@ class PerformanceEvaluator: + """Plugin to run prediction and performance evaluation via official eval script""" def __init__(self, tokenizer, evaluation_dataset, json_data, question_id_mapper): self._tokenizer = tokenizer self._evaluation_dataset = evaluation_dataset @@ -169,6 +170,11 @@ def _get_indices(begin, end, answer_mask_matrix): input tensor with shape `(batch_size, context_sequence_length)` end : NDArray input tensor with shape `(batch_size, context_sequence_length)` + + Returns + ------- + prediction: Tuple + Tuple containing first and last token indices of the answer """ begin_hat = begin.reshape(begin.shape + (1,)) end_hat = end.reshape(end.shape + (1,)) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 0b6ee11c3c..4d955fadb4 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -76,6 +76,13 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, dropout=dropout) def init_embeddings(self, lock_gradients): + """Initialize words embeddings with provided embedding values + + Parameters + ---------- + lock_gradients: bool + Flag to stop parameters from being trained + """ self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) if lock_gradients: @@ -156,6 +163,21 @@ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, bidirectional=biflag, input_size=800) def begin_state(self, ctx, batch_sizes=None): + """Provides begin state for the layer's modeling_layer block + + Parameters + ---------- + ctx: list[Context] + List of contexts to be used + + batch_sizes: list[int] + List of batch-sizes per context + + Returns + ------- + state_list: list + List of states + """ if batch_sizes is None: batch_sizes = [self._batch_size] * len(ctx) @@ -173,7 +195,7 @@ def hybrid_forward(self, F, x, state, *args): class BiDAFOutputLayer(HybridBlock): """ ``BiDAFOutputLayer`` produces the final prediction of an answer. The output is a tuple of - start index and end index of the answer in the paragraph per each batch. + start and end index of token in the paragraph per each batch. It accepts 2 inputs: `x` : the output of Attention layer of shape: @@ -184,10 +206,10 @@ class BiDAFOutputLayer(HybridBlock): Parameters ---------- + batch_size : `int` + Size of a batch span_start_input_dim : `int`, default 100 The number of features in the hidden state h of LSTM - units : `int`, default 4 * ``span_start_input_dim`` - Number of hidden units of `Dense` layer nlayers : `int`, default 1 Number of recurrent layers. biflag: `bool`, default True @@ -222,6 +244,21 @@ def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, flatten=False) def begin_state(self, ctx, batch_sizes=None): + """Provides begin state for the layer's end_index_lstm block + + Parameters + ---------- + ctx: list[Context] + List of contexts to be used + + batch_sizes: list[int] + List of batch-sizes per context + + Returns + ------- + state_list: list + List of states + """ if batch_sizes is None: batch_sizes = [self._batch_size] * len(ctx) @@ -257,14 +294,24 @@ def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=argume class BiDAFModel(HybridBlock): - """Bidirectional attention flow model for Question answering + """Bidirectional attention flow model for Question answering. Implemented according to the + following work: + + @article{DBLP:journals/corr/abs-1804-09541, + author = {Minjoon Seo and + Aniruddha Kembhavi and + Ali Farhadi and + Hannaneh Hajishirzi}, + title = {Bidirectional Attention Flow for Machine Comprehension}, + year = {2016}, + url = {https://arxiv.org/abs/1611.01603} + } """ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): super().__init__(prefix=prefix, params=params) self._options = options with self.name_scope(): - # contextual embedding layer self.ctx_embedding = BiDAFEmbedding(options.batch_size, word_vocab, char_vocab, diff --git a/scripts/question_answering/question_id_mapper.py b/scripts/question_answering/question_id_mapper.py index 5f739b889c..0b8307d77f 100644 --- a/scripts/question_answering/question_id_mapper.py +++ b/scripts/question_answering/question_id_mapper.py @@ -21,6 +21,9 @@ class QuestionIdMapper: + """Stores mapping between question id and context of SQuAD dataset + + """ def __init__(self, dataset): self._question_id_to_context = {item[1]: item[3] for item in dataset} self._question_id_to_idx = {item[1]: item[0] for item in dataset} @@ -28,12 +31,33 @@ def __init__(self, dataset): @property def question_id_to_context(self): + """Provides question Id to context map + + Returns + ------- + map: Dict + Question Id to Context map + """ return self._question_id_to_context @property def idx_to_question_id(self): + """Provides record index to question Id map + + Returns + ------- + map: Dict + Record index to question Id map + """ return self._idx_to_question_id @property def question_id_to_idx(self): + """Provides question Id to record index map + + Returns + ------- + map: Dict + Question Id to record index map + """ return self._question_id_to_idx diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index 31a1953499..98c18bce57 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -17,21 +17,20 @@ # specific language governing permissions and limitations # under the License. import re - import nltk class BiDAFTokenizer: - def __init__(self): - pass - + """Tokenizer that is used for preprocessing data for BiDAF model. It applies basic tokenizer + and some extra preprocessing steps making data ready to be used for training BiDAF + """ def __call__(self, sample, lower_case=False): - """ + """Process single record Parameters ---------- sample: str - The sentence to tokenize + The record to tokenize Returns ------- @@ -52,6 +51,18 @@ def __call__(self, sample, lower_case=False): @staticmethod def _process_tokens(temp_tokens): + """Process tokens by splitting them if split symbol is encountered + + Parameters + ---------- + temp_tokens: list[str] + Tokens to be processed + + Returns + ------- + tokens : list[str] + List of updated tokens + """ tokens = [] splitters = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0") diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 555380918d..8edb6b54d2 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -27,9 +27,6 @@ import logging import pickle -import argparse -import numpy as np -import random from time import time import mxnet as mx @@ -64,8 +61,8 @@ def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): Returns ------- - data : Tuple - A tuple of dataset, QuestionIdMapper and original json data for evaluation + data : SimpleDataset + Transformed dataset """ tokenizer = vocab_provider.get_tokenizer() transformer = SQuADTransform(vocab_provider, options.q_max_len, @@ -78,37 +75,28 @@ def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): for i, record in enumerate(dataset): if enable_filtering: tokenized_question = tokenizer(record[2], lower_case=True) - # we don't need to dispose of context as long as the answer is still - # present in the context after it is trimmed - # tokenized_context = tokenizer(record[3], lower_case=True) - # - # if len(tokenized_context) > options.ctx_max_len: - # long_context += 1 - # continue - - # but we don't know if the question is still meaningful + if len(tokenized_question) > options.q_max_len: long_question += 1 continue transformed_record = transformer(*record) - # if answer end index is after ctx_max_len token or - # it is after q_max_len token we do not use this record + # if answer end index is after ctx_max_len token we do not use this record if enable_filtering and transformed_record[6][0][1] >= options.ctx_max_len: continue transformed_records.append(transformed_record) processed_dataset = SimpleDataset(transformed_records) - print("{}/{} records. Too long context {}, too long query {}".format( + print("{}/{} records left. Too long context {}, too long query {}".format( len(processed_dataset), i + 1, long_context, long_question)) return processed_dataset def get_record_per_answer_span(processed_dataset, options): - """Each record has multiple answers and for training purposes it is better to increase number of - records by creating a record per each answer. + """Each record might have multiple answers and for training purposes it is better to increase + number of records by creating a record per each answer. Parameters ---------- @@ -184,12 +172,15 @@ def get_context(options): options : `Namespace` Command arguments + Returns + ------- + ctx : list[Context] + List of contexts """ ctx = [] if options.gpu is None: ctx.append(mx.cpu(0)) - ctx.append(mx.cpu(1)) print('Use CPU') else: indices = options.gpu.split(',') @@ -404,11 +395,18 @@ def get_learning_rate_per_iteration(iteration, options): """Returns learning rate based on current iteration. Used to implement learning rate warm up technique - :param int iteration: Number of iteration - :param NameSpace options: Training options - :return float: learning rate - """ + Parameters + ---------- + iteration : `int` + Number of iteration + options : `Namespace` + Training options + Returns + ------- + learning_rate : float + Learning rate + """ if options.resume_training: return options.lr @@ -418,10 +416,19 @@ def get_learning_rate_per_iteration(iteration, options): def get_gradients(model, ctx, options): """Get gradients and apply gradient decay to all layers if required. - :param BiDAFModel model: Model in training - :param ctx: Contexts - :param NameSpace options: Training options - :return: Array of gradients + Parameters + ---------- + model : `BiDAFModel` + Model in training + ctx : `Context` + Training context + options : `Namespace` + Training options + + Returns + ------- + gradients : list + List of gradients """ gradients = [] @@ -446,13 +453,24 @@ def reset_embedding_gradients(model, ctx): """Gradients for embedding layer doesn't need to be trained. We train only UNK token of embedding if required. - :param BiDAFModel model: Model in training - :param ctx: Contexts of training + Parameters + ---------- + model : `BiDAFModel` + Model in training + ctx : `Context` + Training context """ model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 def is_fixed_embedding_layer(name): + """Check if this is an embedding layer which parameters are supposed to be fixed + + Parameters + ---------- + name : `str` + Layer name to check + """ return True if "predefined_embedding_layer" in name else False @@ -468,7 +486,6 @@ def save_model_parameters(net, epoch, options): options : `Namespace` Saving arguments """ - if not os.path.exists(options.save_dir): os.mkdir(options.save_dir) @@ -540,12 +557,24 @@ def load_transformed_dataset(path): ---------- path : `str` Loading path + + Returns + ------- + processed_dataset : SimpleDataset + Transformed dataset """ processed_dataset = pickle.load(open(path, "rb")) return processed_dataset def run_preprocess_mode(options): + """Run program in data preprocessing mode + + Parameters + ---------- + options : `Namespace` + Data preprocessing arguments + """ # we use both datasets to create proper vocab dataset_train = SQuAD(segment='train') dataset_dev = SQuAD(segment='dev') @@ -561,6 +590,13 @@ def run_preprocess_mode(options): def run_training_mode(options): + """Run program in data training mode + + Parameters + ---------- + options : `Namespace` + Model training parameters + """ dataset = SQuAD(segment='train') dataset_val = SQuAD(segment='dev') vocab_provider = VocabProvider([dataset, dataset_val], options) @@ -594,6 +630,18 @@ def run_training_mode(options): def run_evaluate_mode(options, existing_net=None, existing_ema=None): + """Run program in evaluating mode + + Parameters + ---------- + options : `Namespace` + Model evaluation arguments + + Returns + ------- + result : dict + Dictionary with exact_match and F1 scores + """ train_dataset = SQuAD(segment='train') dataset = SQuAD(segment='dev') diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index e212798979..cf4f95ab4c 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -114,11 +114,11 @@ def get_args(): parser.add_argument('--lr_warmup_steps', type=int, default=0, help='Defines how many iterations to spend on warming up learning rate') parser.add_argument('--clip', type=float, default=0, help='gradient clipping') - parser.add_argument('--weight_decay', type=float, default=0.0005, + parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay for parameter updates') parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='Report interval applied to last epoch only') - parser.add_argument('--early_stop', type=int, default=4, + parser.add_argument('--early_stop', type=int, default=9, help='Apply early stopping for the last epoch. Stop after # of consequent ' '# of times F1 is lower than max. Should be used with log_interval') parser.add_argument('--resume_training', type=int, default=0, From 634b06996286ec7e081d9aa1ef821210c0816e2e Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 25 Oct 2018 16:12:05 -0700 Subject: [PATCH 34/43] Return static_alloc --- scripts/question_answering/train_question_answering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 8edb6b54d2..d3432a4ebc 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -615,7 +615,7 @@ def run_training_mode(options): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.cast(options.precision) net.initialize(init.Xavier(), ctx=ctx) - net.hybridize() + net.hybridize(static_alloc=True) if options.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') From c4f4a3c73b3100c9f34d9053e275db9186dad866 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Tue, 30 Oct 2018 14:44:52 -0700 Subject: [PATCH 35/43] Multigpu and arbitrary batch support for evals --- .../performance_evaluator.py | 43 ++++++++++++------- scripts/question_answering/utils.py | 22 ++++++++++ 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 38d2266245..ad55a0495c 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -19,11 +19,12 @@ """Performance evaluator - a proxy class used for plugging in official validation script""" import multiprocessing -from mxnet import nd, gluon +from mxnet import nd, gluon, cpu from mxnet.gluon.data import DataLoader, ArrayDataset from scripts.question_answering.data_processing import SQuADTransform from scripts.question_answering.official_squad_eval_script import evaluate +from scripts.question_answering.utils import extend_to_batch_size class PerformanceEvaluator: @@ -56,14 +57,16 @@ def evaluate_performance(self, net, ctx, options): # Allows to ensure that start index is always <= than end index for c in ctx: - answer_mask_matrix = nd.zeros(shape=(1, options.ctx_max_len, options.ctx_max_len), ctx=c) + answer_mask_matrix = nd.zeros(shape=(1, options.ctx_max_len, options.ctx_max_len), + ctx=cpu(0)) for idx in range(options.answer_max_len): answer_mask_matrix += nd.eye(N=options.ctx_max_len, M=options.ctx_max_len, - k=idx, ctx=c) + k=idx, ctx=cpu(0)) eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) for r in self._evaluation_dataset]) - eval_dataloader = DataLoader(eval_dataset, batch_size=options.batch_size, + eval_dataloader = DataLoader(eval_dataset, + batch_size=len(ctx) * options.batch_size, last_batch='keep', pin_memory=True, num_workers=(multiprocessing.cpu_count() - len(ctx) - 2)) @@ -71,11 +74,16 @@ def evaluate_performance(self, net, ctx, options): for i, data in enumerate(eval_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - record_index = record_index.astype(options.precision) - q_words = q_words.astype(options.precision) - ctx_words = ctx_words.astype(options.precision) - q_chars = q_chars.astype(options.precision) - ctx_chars = ctx_chars.astype(options.precision) + record_index = extend_to_batch_size(options.batch_size * len(ctx), + record_index.astype(options.precision), -1) + q_words = extend_to_batch_size(options.batch_size * len(ctx), + q_words.astype(options.precision)) + ctx_words = extend_to_batch_size(options.batch_size * len(ctx), + ctx_words.astype(options.precision)) + q_chars = extend_to_batch_size(options.batch_size * len(ctx), + q_chars.astype(options.precision)) + ctx_chars = extend_to_batch_size(options.batch_size * len(ctx), + ctx_chars.astype(options.precision)) record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) @@ -103,20 +111,25 @@ def evaluate_performance(self, net, ctx, options): q_embedding_begin_state, m_layer_begin_state, o_layer_begin_state) - outs.append((begin, end)) + outs.append((ri.as_in_context(cpu(0)), + begin.as_in_context(cpu(0)), + end.as_in_context(cpu(0)))) for out in outs: - start = out[0].softmax(axis=1) - end = out[1].softmax(axis=1) + ri = out[0] + start = out[1].softmax(axis=1) + end = out[2].softmax(axis=1) start_end_span = PerformanceEvaluator._get_indices(start, end, answer_mask_matrix) # iterate over batches - for idx, start_end in zip(data[0], start_end_span): + for idx, start_end in zip(ri, start_end_span): idx = int(idx.asscalar()) start = int(start_end[0].asscalar()) end = int(start_end[1].asscalar()) - question_id = self._mapper.idx_to_question_id[idx] - pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) + + if idx in self._mapper.idx_to_question_id: + question_id = self._mapper.idx_to_question_id[idx] + pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) if options.save_prediction_path: with open(options.save_prediction_path, "w") as f: diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index cf4f95ab4c..878225c8f3 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -158,6 +158,28 @@ def get_very_negative_number(): return -1e30 +def extend_to_batch_size(batch_size, prototype, fill_value=0): + """Provides NDArray, which consist of prototype NDArray and NDArray filled with fill_value to + batch_size number of items. New NDArray appended to batch dimension (dim=0). + + Parameters + ---------- + batch_size: ``int`` + Expected value for batch_size dimension (dim=0). + prototype: ``NDArray`` + NDArray to be extended of shape (batch_size, ...) + fill_value: ``float`` + Value to use for filling + """ + if batch_size == prototype.shape[0]: + return prototype + + new_shape = (batch_size - prototype.shape[0], ) + prototype.shape[1:] + dummy_elements = nd.full(val=fill_value, shape=new_shape, dtype=prototype.dtype, + ctx=prototype.context) + return nd.concat(prototype, dummy_elements, dim=0) + + def get_combined_dim(combination, tensor_dims): """ For use with :func:`combine_tensors`. This function computes the resultant dimension when From 6ee773293a738191b9c3f9a9d72dc66eaf8cd983 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 1 Nov 2018 16:12:15 -0700 Subject: [PATCH 36/43] Add terminate training if need to reach F1 only --- .../question_answering/train_question_answering.py | 11 ++++++++++- scripts/question_answering/utils.py | 5 +++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index d3432a4ebc..fe00ae133d 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -343,7 +343,6 @@ def run_training(net, dataloader, ctx, options): options.log_interval > 0 and \ iteration > 0 and iteration % options.log_interval == 0: evaluate_options = copy.deepcopy(options) - evaluate_options.batch_size = 10 evaluate_options.epochs = iteration result = run_evaluate_mode(evaluate_options, net, ema) @@ -388,6 +387,16 @@ def run_training(net, dataloader, ctx, options): save_ema_parameters(ema, e, options) save_trainer_parameters(trainer, e, options) + if options.terminate_training_on_reaching_F1_threshold: + evaluate_options = copy.deepcopy(options) + evaluate_options.epochs = e + result = run_evaluate_mode(evaluate_options, net, ema) + + if result["f1"] >= options.terminate_training_on_reaching_F1_threshold: + print("Finishing training on {} epoch, because dev F1 score is >= required {}. {}" + .format(e, options.terminate_training_on_reaching_F1_threshold, result)) + break + print("Training time {:6.2f} seconds".format(time() - train_start)) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 878225c8f3..3e0e5e4ce5 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -123,6 +123,11 @@ def get_args(): '# of times F1 is lower than max. Should be used with log_interval') parser.add_argument('--resume_training', type=int, default=0, help='Resume training from this epoch number') + parser.add_argument('--terminate_training_on_reaching_F1_threshold', type=float, default=0, + help='Some tasks, like DAWNBenchmark requires to minimize training time ' + 'while reaching a particular F1 metric. This parameter controls if ' + 'training should be terminated as soon as F1 is reached to minimize ' + 'training time and cost. It would force to do evaluation every epoch.') parser.add_argument('--save_dir', type=str, default='out_dir', help='directory path to save the final model and training log') parser.add_argument('--word_vocab_path', type=str, default=None, From 11e66c8e5aae47efad08bdcd86e573c832eca948 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 7 Nov 2018 13:31:49 -0800 Subject: [PATCH 37/43] Code review comments addressed --- scripts/question_answering/bidaf.py | 2 +- .../train_question_answering.py | 97 ++++++++++++++++++- scripts/question_answering/utils.py | 84 ---------------- 3 files changed, 94 insertions(+), 89 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index a9f573fa74..df1c0271e2 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -10,7 +10,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writinConvolutionalEncoderg, +# Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index fe00ae133d..7f8476be81 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import argparse import math import copy @@ -42,7 +43,7 @@ from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper from scripts.question_answering.tokenizer import BiDAFTokenizer -from scripts.question_answering.utils import logging_config, get_args +from scripts.question_answering.utils import logging_config def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): @@ -232,12 +233,14 @@ def run_training(net, dataloader, ctx, options): max_dev_f1 = -1 max_iteration = -1 early_stop_tries = 0 + records_per_epoch_count = 0 print("Starting training...") for e in range(0 if not options.resume_training else options.resume_training, options.epochs): avg_loss *= 0 # Zero average loss of each epoch + records_per_epoch_count = 0 ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) @@ -250,6 +253,7 @@ def run_training(net, dataloader, ctx, options): e_start = time() record_index, q_words, ctx_words, q_chars, ctx_chars = data + records_per_epoch_count += record_index.shape[0] record_index = record_index.astype(options.precision) q_words = q_words.astype(options.precision) @@ -378,10 +382,10 @@ def run_training(net, dataloader, ctx, options): avg_loss_scalar = avg_loss.asscalar() epoch_time = time() - e_start - print("\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} | " - "Time per epoch {:5.2f} seconds" + print("\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} " + "| throughtput {:5.3f} of samples/sec | Time per epoch {:5.2f} seconds" .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, - epoch_time)) + records_per_epoch_count / epoch_time, epoch_time)) save_model_parameters(net, e, options) save_ema_parameters(ema, e, options) @@ -696,6 +700,91 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): return evaluator.evaluate_performance(net, ctx, options) +def get_args(): + """Get console arguments + """ + parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') + parser.add_argument('--preprocess', default=False, action='store_true', + help='Preprocess dataset') + parser.add_argument('--train', default=False, action='store_true', + help='Run training') + parser.add_argument('--evaluate', default=False, action='store_true', + help='Run evaluation on dev dataset') + parser.add_argument('--preprocessed_dataset_path', type=str, + default="preprocessed_dataset.p", help='Path to preprocessed dataset') + parser.add_argument('--preprocessed_val_dataset_path', type=str, + default="preprocessed_val_dataset.p", + help='Path to preprocessed validation dataset') + parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') + parser.add_argument('--embedding_size', type=int, default=100, + help='Dimension of the word embedding') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, + help='Number of layers in Contextual embedding layer of BiDAF') + parser.add_argument('--highway_num_layers', type=int, default=2, + help='Number of layers in Highway layer of BiDAF') + parser.add_argument('--modeling_num_layers', type=int, default=2, + help='Number of layers in Modeling layer of BiDAF') + parser.add_argument('--output_num_layers', type=int, default=1, + help='Number of layers in Output layer of BiDAF') + parser.add_argument('--batch_size', type=int, default=60, help='Batch size') + parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') + parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') + parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') + parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') + parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') + parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') + parser.add_argument('--rho', type=float, default=0.9, + help='Adadelta decay rate for both squared gradients and delta.') + parser.add_argument('--lr_warmup_steps', type=int, default=0, + help='Defines how many iterations to spend on warming up learning rate') + parser.add_argument('--clip', type=float, default=0, help='gradient clipping') + parser.add_argument('--weight_decay', type=float, default=0, + help='Weight decay for parameter updates') + parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='Report interval applied to last epoch only') + parser.add_argument('--early_stop', type=int, default=9, + help='Apply early stopping for the last epoch. Stop after # of consequent ' + '# of times F1 is lower than max. Should be used with log_interval') + parser.add_argument('--resume_training', type=int, default=0, + help='Resume training from this epoch number') + parser.add_argument('--terminate_training_on_reaching_F1_threshold', type=float, default=0, + help='Some tasks, like DAWNBenchmark requires to minimize training time ' + 'while reaching a particular F1 metric. This parameter controls if ' + 'training should be terminated as soon as F1 is reached to minimize ' + 'training time and cost. It would force to do evaluation every epoch.') + parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') + parser.add_argument('--word_vocab_path', type=str, default=None, + help='Path to preprocessed word-level vocabulary') + parser.add_argument('--char_vocab_path', type=str, default=None, + help='Path to preprocessed character-level vocabulary') + parser.add_argument('--gpu', type=str, default=None, + help='Coma-separated ids of the gpu to use. Empty means to use cpu.') + parser.add_argument('--train_unk_token', default=False, action='store_true', + help='Should train unknown token of embedding') + parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], + help='Use float16 or float32 precision') + parser.add_argument('--filter_long_context', default=True, action='store_false', + help='Filter contexts if the answer is after ctx_max_len') + parser.add_argument('--save_prediction_path', type=str, default='', + help='Path to save predictions') + parser.add_argument('--use_multiprecision_in_optimizer', default=True, action='store_false', + help='When using float16, shall optimizer use multiprecision.') + parser.add_argument('--use_exponential_moving_average', default=True, action='store_false', + help='Should averaged copy of parameters been stored and used ' + 'during evaluation.') + parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, + help='Weight decay used in exponential moving average') + parser.add_argument('--grad_req_add_mode', type=int, default=0, + help='Enable rolling gradient mode, where batch size is always 1 and ' + 'gradients are accumulated using single GPU') + + args = parser.parse_args() + return args + + if __name__ == "__main__": args = get_args() args.batch_size = int(args.batch_size / len(get_context(args))) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 3e0e5e4ce5..c4413fd44b 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -74,90 +74,6 @@ def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=lo return folder -def get_args(): - """Get console arguments - """ - parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') - parser.add_argument('--preprocess', default=False, action='store_true', - help='Preprocess dataset') - parser.add_argument('--train', default=False, action='store_true', - help='Run training') - parser.add_argument('--evaluate', default=False, action='store_true', - help='Run evaluation on dev dataset') - parser.add_argument('--preprocessed_dataset_path', type=str, - default="preprocessed_dataset.p", help='Path to preprocessed dataset') - parser.add_argument('--preprocessed_val_dataset_path', type=str, - default="preprocessed_val_dataset.p", - help='Path to preprocessed validation dataset') - parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') - parser.add_argument('--embedding_size', type=int, default=100, - help='Dimension of the word embedding') - parser.add_argument('--dropout', type=float, default=0.2, - help='dropout applied to layers (0 = no dropout)') - parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, - help='Number of layers in Contextual embedding layer of BiDAF') - parser.add_argument('--highway_num_layers', type=int, default=2, - help='Number of layers in Highway layer of BiDAF') - parser.add_argument('--modeling_num_layers', type=int, default=2, - help='Number of layers in Modeling layer of BiDAF') - parser.add_argument('--output_num_layers', type=int, default=1, - help='Number of layers in Output layer of BiDAF') - parser.add_argument('--batch_size', type=int, default=60, help='Batch size') - parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') - parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') - parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') - parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') - parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') - parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') - parser.add_argument('--rho', type=float, default=0.9, - help='Adadelta decay rate for both squared gradients and delta.') - parser.add_argument('--lr_warmup_steps', type=int, default=0, - help='Defines how many iterations to spend on warming up learning rate') - parser.add_argument('--clip', type=float, default=0, help='gradient clipping') - parser.add_argument('--weight_decay', type=float, default=0, - help='Weight decay for parameter updates') - parser.add_argument('--log_interval', type=int, default=100, metavar='N', - help='Report interval applied to last epoch only') - parser.add_argument('--early_stop', type=int, default=9, - help='Apply early stopping for the last epoch. Stop after # of consequent ' - '# of times F1 is lower than max. Should be used with log_interval') - parser.add_argument('--resume_training', type=int, default=0, - help='Resume training from this epoch number') - parser.add_argument('--terminate_training_on_reaching_F1_threshold', type=float, default=0, - help='Some tasks, like DAWNBenchmark requires to minimize training time ' - 'while reaching a particular F1 metric. This parameter controls if ' - 'training should be terminated as soon as F1 is reached to minimize ' - 'training time and cost. It would force to do evaluation every epoch.') - parser.add_argument('--save_dir', type=str, default='out_dir', - help='directory path to save the final model and training log') - parser.add_argument('--word_vocab_path', type=str, default=None, - help='Path to preprocessed word-level vocabulary') - parser.add_argument('--char_vocab_path', type=str, default=None, - help='Path to preprocessed character-level vocabulary') - parser.add_argument('--gpu', type=str, default=None, - help='Coma-separated ids of the gpu to use. Empty means to use cpu.') - parser.add_argument('--train_unk_token', default=False, action='store_true', - help='Should train unknown token of embedding') - parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], - help='Use float16 or float32 precision') - parser.add_argument('--filter_long_context', default=True, action='store_false', - help='Filter contexts if the answer is after ctx_max_len') - parser.add_argument('--save_prediction_path', type=str, default='', - help='Path to save predictions') - parser.add_argument('--use_multiprecision_in_optimizer', default=True, action='store_false', - help='When using float16, shall optimizer use multiprecision.') - parser.add_argument('--use_exponential_moving_average', default=True, action='store_false', - help='Should averaged copy of parameters been stored and used ' - 'during evaluation.') - parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, - help='Weight decay used in exponential moving average') - parser.add_argument('--grad_req_add_mode', type=int, default=0, - help='Enable rolling gradient mode, where batch size is always 1 and ' - 'gradients are accumulated using single GPU') - - args = parser.parse_args() - return args - def get_very_negative_number(): return -1e30 From 5bddd4851bd086d1f4cf53928a228ea02d02790c Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Wed, 7 Nov 2018 15:15:26 -0800 Subject: [PATCH 38/43] FP16 removed and tests are fixed --- scripts/question_answering/bidaf.py | 8 - .../performance_evaluator.py | 36 +--- .../question_answering/question_answering.py | 108 ++-------- .../train_question_answering.py | 39 +--- scripts/tests/test_question_answering.py | 191 ++++++------------ 5 files changed, 94 insertions(+), 288 deletions(-) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index df1c0271e2..b06d7ac4eb 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -36,7 +36,6 @@ def __init__(self, passage_length, question_length, encoding_dim, - precision, **kwargs): super(BidirectionalAttentionFlow, self).__init__(**kwargs) @@ -44,18 +43,11 @@ def __init__(self, self._passage_length = passage_length self._question_length = question_length self._encoding_dim = encoding_dim - self._precision = precision def _get_big_negative_value(self): - if self._precision == 'float16': - return np.finfo(np.float16).min - else: return np.finfo(np.float32).min def _get_small_positive_value(self): - if self._precision == 'float16': - return np.finfo(np.float16).eps - else: return np.finfo(np.float32).eps def hybrid_forward(self, F, passage_question_similarity, diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index ad55a0495c..afe4840f8d 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -74,16 +74,11 @@ def evaluate_performance(self, net, ctx, options): for i, data in enumerate(eval_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - record_index = extend_to_batch_size(options.batch_size * len(ctx), - record_index.astype(options.precision), -1) - q_words = extend_to_batch_size(options.batch_size * len(ctx), - q_words.astype(options.precision)) - ctx_words = extend_to_batch_size(options.batch_size * len(ctx), - ctx_words.astype(options.precision)) - q_chars = extend_to_batch_size(options.batch_size * len(ctx), - q_chars.astype(options.precision)) - ctx_chars = extend_to_batch_size(options.batch_size * len(ctx), - ctx_chars.astype(options.precision)) + record_index = extend_to_batch_size(options.batch_size * len(ctx), record_index, -1) + q_words = extend_to_batch_size(options.batch_size * len(ctx), q_words) + ctx_words = extend_to_batch_size(options.batch_size * len(ctx), ctx_words) + q_chars = extend_to_batch_size(options.batch_size * len(ctx), q_chars) + ctx_chars = extend_to_batch_size(options.batch_size * len(ctx), ctx_chars) record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) @@ -91,26 +86,11 @@ def evaluate_performance(self, net, ctx, options): q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) - ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) - o_layer_begin_state_list = net.output_layer.begin_state(ctx) - outs = [] - for ri, qw, cw, qc, cc, ctx_embedding_begin_state, \ - q_embedding_begin_state, m_layer_begin_state, \ - o_layer_begin_state in zip(record_index, q_words, ctx_words, - q_chars, ctx_chars, - ctx_embedding_begin_state_list, - q_embedding_begin_state_list, - m_layer_begin_state_list, - o_layer_begin_state_list): - begin, end = net(qw, cw, qc, cc, - ctx_embedding_begin_state, - q_embedding_begin_state, - m_layer_begin_state, - o_layer_begin_state) + for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars): + begin, end = net(qw, cw, qc, cc) outs.append((ri.as_in_context(cpu(0)), begin.as_in_context(cpu(0)), end.as_in_context(cpu(0)))) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 4d955fadb4..5c499abe89 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -42,13 +42,12 @@ class BiDAFEmbedding(HybridBlock): """ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, contextual_embedding_nlayers=2, highway_nlayers=2, embedding_size=100, - dropout=0.2, precision='float32', prefix=None, params=None): + dropout=0.2, prefix=None, params=None): super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) self._word_vocab = word_vocab self._batch_size = batch_size self._max_seq_len = max_seq_len - self._precision = precision self._embedding_size = embedding_size with self.name_scope(): @@ -88,17 +87,7 @@ def init_embeddings(self, lock_gradients): if lock_gradients: self._word_embedding.collect_params().setattr('grad_req', 'null') - def begin_state(self, ctx, batch_sizes=None): - if batch_sizes is None: - batch_sizes = [self._batch_size] * len(ctx) - - state_list = [self._contextual_embedding.begin_state(b, - dtype=self._precision, - ctx=c) for c, b in zip(ctx, - batch_sizes)] - return state_list - - def hybrid_forward(self, F, w, c, contextual_embedding_state, *args): + def hybrid_forward(self, F, w, c, *args): word_embedded = self._word_embedding(w) char_level_data = self._char_dense_embedding(c) char_level_data = self._dropout(char_level_data) @@ -122,8 +111,7 @@ def highway(token_of_all_batches, _): highway_output, _ = F.contrib.foreach(highway, highway_input, []) # Transpose to TNC - default for LSTM - ce_output, ce_state = self._contextual_embedding(highway_output, - contextual_embedding_state) + ce_output = self._contextual_embedding(highway_output) return ce_output @@ -152,43 +140,17 @@ class BiDAFModelingLayer(HybridBlock): Shared Parameters for this `Block`. """ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, - dropout=0.2, precision='float32', prefix=None, params=None): + dropout=0.2, prefix=None, params=None): super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) self._batch_size = batch_size - self._precision = precision with self.name_scope(): self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=800) - def begin_state(self, ctx, batch_sizes=None): - """Provides begin state for the layer's modeling_layer block - - Parameters - ---------- - ctx: list[Context] - List of contexts to be used - - batch_sizes: list[int] - List of batch-sizes per context - - Returns - ------- - state_list: list - List of states - """ - if batch_sizes is None: - batch_sizes = [self._batch_size] * len(ctx) - - state_list = [self._modeling_layer.begin_state(b, - dtype=self._precision, - ctx=c) for c, b in zip(ctx, - batch_sizes)] - return state_list - - def hybrid_forward(self, F, x, state, *args): - out, _ = self._modeling_layer(x, state) + def hybrid_forward(self, F, x, *args): + out = self._modeling_layer(x) return out @@ -223,11 +185,10 @@ class BiDAFOutputLayer(HybridBlock): Shared Parameters for this `Block`. """ def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, - dropout=0.2, precision='float32', prefix=None, params=None): + dropout=0.2, prefix=None, params=None): super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) self._batch_size = batch_size - self._precision = precision with self.name_scope(): self._dropout = nn.Dropout(rate=dropout) @@ -243,32 +204,7 @@ def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, self._end_index_m = nn.Dense(units=1, in_units=2 * span_start_input_dim, flatten=False) - def begin_state(self, ctx, batch_sizes=None): - """Provides begin state for the layer's end_index_lstm block - - Parameters - ---------- - ctx: list[Context] - List of contexts to be used - - batch_sizes: list[int] - List of batch-sizes per context - - Returns - ------- - state_list: list - List of states - """ - if batch_sizes is None: - batch_sizes = [self._batch_size] * len(ctx) - - state_list = [self._end_index_lstm.begin_state(b, - dtype=self._precision, - ctx=c) for c, b in zip(ctx, - batch_sizes)] - return state_list - - def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=arguments-differ + def hybrid_forward(self, F, x, m, mask, *args): # pylint: disable=arguments-differ # setting batch size as the first dimension x = F.transpose(x, axes=(1, 0, 2)) @@ -276,7 +212,7 @@ def hybrid_forward(self, F, x, m, mask, state, *args): # pylint: disable=argume self._start_index_m(self._dropout(F.transpose(m, axes=(1, 0, 2)))) - m2, _ = self._end_index_lstm(m, state) + m2 = self._end_index_lstm(m) end_index_dense_output = self._end_index_g(self._dropout(x)) + \ self._end_index_m(self._dropout(F.transpose(m2, axes=(1, 0, 2)))) @@ -320,7 +256,6 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.highway_num_layers, options.embedding_size, dropout=options.dropout, - precision=options.precision, prefix="context_embedding") self.similarity_function = LinearSimilarity(array_1_dim=6 * options.embedding_size, @@ -337,32 +272,24 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): self.attention_layer = BidirectionalAttentionFlow(options.batch_size, options.ctx_max_len, options.q_max_len, - 2 * options.embedding_size, - options.precision) + 2 * options.embedding_size) self.modeling_layer = BiDAFModelingLayer(options.batch_size, input_dim=options.embedding_size, nlayers=options.modeling_num_layers, - dropout=options.dropout, - precision=options.precision) + dropout=options.dropout) self.output_layer = BiDAFOutputLayer(options.batch_size, span_start_input_dim=options.embedding_size, nlayers=options.output_num_layers, - dropout=options.dropout, - precision=options.precision) + dropout=options.dropout) def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): super(BiDAFModel, self).initialize(init, ctx, verbose, force_reinit) self.ctx_embedding.init_embeddings(not self._options.train_unk_token) - def hybrid_forward(self, F, qw, cw, qc, cc, - ctx_embedding_states=None, - q_embedding_states=None, - modeling_layer_states=None, - output_layer_states=None, - *args): - ctx_embedding_output = self.ctx_embedding(cw, cc, ctx_embedding_states) - q_embedding_output = self.ctx_embedding(qw, qc, q_embedding_states) + def hybrid_forward(self, F, qw, cw, qc, cc, *args): + ctx_embedding_output = self.ctx_embedding(cw, cc) + q_embedding_output = self.ctx_embedding(qw, qc) # attention layer expect batch_size x seq_length x channels ctx_embedding_output = F.transpose(ctx_embedding_output, axes=(1, 0, 2)) @@ -388,9 +315,8 @@ def hybrid_forward(self, F, qw, cw, qc, cc, attention_layer_output = F.transpose(attention_layer_output, axes=(1, 0, 2)) # modeling layer expects seq_length x batch_size x channels - modeling_layer_output = self.modeling_layer(attention_layer_output, modeling_layer_states) + modeling_layer_output = self.modeling_layer(attention_layer_output) - output = self.output_layer(attention_layer_output, modeling_layer_output, ctx_mask, - output_layer_states) + output = self.output_layer(attention_layer_output, modeling_layer_output, ctx_mask) return output diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 7f8476be81..2251ef4151 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -209,9 +209,6 @@ def run_training(net, dataloader, ctx, options): hyperparameters = {'learning_rate': options.lr} - if options.precision == 'float16' and options.use_multiprecision_in_optimizer: - hyperparameters["multi_precision"] = True - if options.rho: hyperparameters["rho"] = options.rho @@ -227,13 +224,12 @@ def run_training(net, dataloader, ctx, options): ema = None train_start = time() - avg_loss = mx.nd.zeros((1,), ctx=ctx[0], dtype=options.precision) + avg_loss = mx.nd.zeros((1,), ctx=ctx[0]) iteration = 1 max_dev_exact = -1 max_dev_f1 = -1 max_iteration = -1 early_stop_tries = 0 - records_per_epoch_count = 0 print("Starting training...") @@ -242,11 +238,6 @@ def run_training(net, dataloader, ctx, options): avg_loss *= 0 # Zero average loss of each epoch records_per_epoch_count = 0 - ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) - o_layer_begin_state_list = net.output_layer.begin_state(ctx) - for i, (data, label) in enumerate(dataloader): # start timing for the first batch of epoch if i == 0: @@ -255,13 +246,6 @@ def run_training(net, dataloader, ctx, options): record_index, q_words, ctx_words, q_chars, ctx_chars = data records_per_epoch_count += record_index.shape[0] - record_index = record_index.astype(options.precision) - q_words = q_words.astype(options.precision) - ctx_words = ctx_words.astype(options.precision) - q_chars = q_chars.astype(options.precision) - ctx_chars = ctx_chars.astype(options.precision) - label = label.astype(options.precision) - record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) @@ -271,20 +255,10 @@ def run_training(net, dataloader, ctx, options): losses = [] - for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ - q_embedding_begin_state, m_layer_begin_state, \ - o_layer_begin_state in zip(record_index, q_words, ctx_words, - q_chars, ctx_chars, label, - ctx_embedding_begin_state_list, - q_embedding_begin_state_list, - m_layer_begin_state_list, - o_layer_begin_state_list): + for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): with autograd.record(): - begin, end = net(qw, cw, qc, cc, - ctx_embedding_begin_state, - q_embedding_begin_state, - m_layer_begin_state, - o_layer_begin_state) + begin, end = net(qw, cw, qc, cc) begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) loss = loss_function(begin, begin_end[0]) + \ loss_function(end, begin_end[1]) @@ -626,7 +600,6 @@ def run_training_mode(options): ctx = get_context(options) net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") - net.cast(options.precision) net.initialize(init.Xavier(), ctx=ctx) net.hybridize(static_alloc=True) @@ -764,14 +737,10 @@ def get_args(): help='Coma-separated ids of the gpu to use. Empty means to use cpu.') parser.add_argument('--train_unk_token', default=False, action='store_true', help='Should train unknown token of embedding') - parser.add_argument('--precision', type=str, default='float32', choices=['float16', 'float32'], - help='Use float16 or float32 precision') parser.add_argument('--filter_long_context', default=True, action='store_false', help='Filter contexts if the answer is after ctx_max_len') parser.add_argument('--save_prediction_path', type=str, default='', help='Path to save predictions') - parser.add_argument('--use_multiprecision_in_optimizer', default=True, action='store_false', - help='When using float16, shall optimizer use multiprecision.') parser.add_argument('--use_exponential_moving_average', default=True, action='store_false', help='Should averaged copy of parameters been stored and used ' 'during evaluation.') diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index f2df79131d..31364af021 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -29,13 +29,14 @@ import gluonnlp as nlp from gluonnlp.data import SQuAD +from scripts.question_answering.attention_flow import AttentionFlow from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.data_processing import SQuADTransform, VocabProvider from scripts.question_answering.exponential_moving_average import PolyakAveraging from scripts.question_answering.performance_evaluator import PerformanceEvaluator from scripts.question_answering.question_answering import * from scripts.question_answering.question_id_mapper import QuestionIdMapper -from scripts.question_answering.similarity_function import DotProductSimilarity +from scripts.question_answering.similarity_function import DotProductSimilarity, LinearSimilarity from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.train_question_answering import get_record_per_answer_span @@ -49,7 +50,7 @@ @pytest.mark.serial def test_transform_to_nd_array(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset, get_args(batch_size)) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length, max_chars_per_word, embedding_size) record = dataset[0] @@ -62,7 +63,7 @@ def test_transform_to_nd_array(): @pytest.mark.serial def test_data_loader_able_to_read(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset, get_args(batch_size)) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length, max_chars_per_word, embedding_size) record = dataset[0] @@ -85,7 +86,7 @@ def test_data_loader_able_to_read(): @pytest.mark.serial def test_load_vocabs(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset, get_args(batch_size)) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) assert vocab_provider.get_word_level_vocab(embedding_size) is not None assert vocab_provider.get_char_level_vocab() is not None @@ -93,7 +94,7 @@ def test_load_vocabs(): def test_bidaf_embedding(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset, get_args(batch_size)) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length, max_chars_per_word, embedding_size) @@ -111,26 +112,17 @@ def test_bidaf_embedding(): embedding = BiDAFEmbedding(word_vocab=word_vocab, char_vocab=char_vocab, batch_size=batch_size, - max_seq_len=question_max_length, - precision="float16") - embedding.cast("float16") - embedding.initialize(init.Xavier(magnitude=2.24)) + max_seq_len=question_max_length) + embedding.initialize(init.Xavier(magnitude=2.24), ctx=mx.cpu_pinned()) embedding.hybridize(static_alloc=True) - state = embedding.begin_state(mx.cpu()) - trainer = Trainer(embedding.collect_params(), "sgd", {"learning_rate": 0.1, - "multi_precision": True}) + trainer = Trainer(embedding.collect_params(), "sgd", {"learning_rate": 0.1}) for i, (data, label) in enumerate(dataloader): with autograd.record(): record_index, q_words, ctx_words, q_chars, ctx_chars = data - q_words = q_words.astype("float16") - ctx_words = ctx_words.astype("float16") - q_chars = q_chars.astype("float16") - ctx_chars = ctx_chars.astype("float16") - label = label.astype("float16") # passing only question_words_nd and question_chars_nd batch - out = embedding(q_words, q_chars, state) + out = embedding(q_words, q_chars) assert out is not None out.backward() @@ -139,28 +131,35 @@ def test_bidaf_embedding(): def test_attention_layer(): - ctx_fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 2 * embedding_size), - dtype="float16") + ctx_fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 2 * embedding_size)) - q_fake_data = nd.random.uniform(shape=(batch_size, question_max_length, 2 * embedding_size), - dtype="float16") + q_fake_data = nd.random.uniform(shape=(batch_size, question_max_length, 2 * embedding_size)) - ctx_fake_mask = nd.ones(shape=(batch_size, context_max_length), dtype="float16") - q_fake_mask = nd.ones(shape=(batch_size, question_max_length), dtype="float16") + ctx_fake_mask = nd.ones(shape=(batch_size, context_max_length)) + q_fake_mask = nd.ones(shape=(batch_size, question_max_length)) - layer = BidirectionalAttentionFlow(DotProductSimilarity(), - batch_size, + matrix_attention = AttentionFlow(LinearSimilarity(array_1_dim=6 * embedding_size, + array_2_dim=1, + combination="x,y,x*y"), + batch_size, + context_max_length, + question_max_length, + 2 * embedding_size) + + layer = BidirectionalAttentionFlow(batch_size, context_max_length, question_max_length, - 2 * embedding_size, - "float16") + 2 * embedding_size) - layer.cast("float16") + matrix_attention.initialize() layer.initialize() - layer.hybridize(static_alloc=True) with autograd.record(): - output = layer(ctx_fake_data, q_fake_data, q_fake_mask, ctx_fake_mask) + passage_question_similarity = matrix_attention(ctx_fake_data, q_fake_data).reshape( + shape=(batch_size, context_max_length, question_max_length)) + + output = layer(passage_question_similarity, ctx_fake_data, q_fake_data, + q_fake_mask, ctx_fake_mask) assert output.shape == (batch_size, context_max_length, 8 * embedding_size) @@ -169,22 +168,18 @@ def test_modeling_layer(): # The modeling layer receive input in a shape of batch_size x T x 8d # T is the sequence length of context which is context_max_length # d is the size of embedding, which is embedding_size - fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size), - dtype="float16") + fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size)) # We assume that attention is already return data in TNC format attention_output = nd.transpose(fake_data, axes=(1, 0, 2)) - layer = BiDAFModelingLayer(batch_size, precision="float16") - layer.cast("float16") + layer = BiDAFModelingLayer(batch_size) layer.initialize() layer.hybridize(static_alloc=True) - state = layer.begin_state(mx.cpu()) - trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": "0.1", - "multi_precision": True}) + trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": "0.1"}) with autograd.record(): - output = layer(attention_output, state) + output = layer(attention_output) output.backward() # According to the paper, the output should be 2d x T @@ -197,36 +192,30 @@ def test_output_layer(): # (batch_size, context_max_length, 8 * embedding_size) # The modeling layer returns data in TNC format - modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size), - dtype="float16") + modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size)) # The layer assumes that attention is already return data in TNC format - attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size), - dtype="float16") + attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size)) + ctx_mask = nd.ones(shape=(batch_size, context_max_length)) - layer = BiDAFOutputLayer(batch_size, precision="float16") - layer.cast("float16") + layer = BiDAFOutputLayer(batch_size) # The model doesn't need to know the hidden states, so I don't hold variables for the states layer.initialize() layer.hybridize(static_alloc=True) - state = layer.begin_state(mx.cpu()) - - trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": 0.1, - "multi_precision": True}) with autograd.record(): - output = layer(attention_output, modeling_output, state) + output = layer(attention_output, modeling_output, ctx_mask) - output.backward() # We expect final numbers as batch_size x 2 (first start index, second end index) - assert output.shape == (batch_size, 2, 400) + assert output[0].shape == (batch_size, 400) and \ + output[1].shape == (batch_size, 400) def test_bidaf_model(): options = get_args(batch_size) - ctx = [mx.cpu(0), mx.cpu(1)] + ctx = [mx.cpu(0)] dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset, options) + vocab_provider = VocabProvider([dataset], options) transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length, max_chars_per_word, embedding_size) @@ -235,7 +224,7 @@ def test_bidaf_model(): if i < options.batch_size * len(ctx)]) # need to remove question id before feeding the data to data loader - loadable_data, dataloader = get_record_per_answer_span(processed_dataset, options) + train_dataset, train_dataloader = get_record_per_answer_span(processed_dataset, options) word_vocab = vocab_provider.get_word_level_vocab(embedding_size) word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) @@ -245,26 +234,14 @@ def test_bidaf_model(): char_vocab=char_vocab, options=options) - net.cast("float16") net.initialize(init.Xavier(magnitude=2.24)) net.hybridize(static_alloc=True) - ctx_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - q_embedding_begin_state_list = net.ctx_embedding.begin_state(ctx) - m_layer_begin_state_list = net.modeling_layer.begin_state(ctx) - o_layer_begin_state_list = net.output_layer.begin_state(ctx) - loss_function = SoftmaxCrossEntropyLoss() - trainer = Trainer(net.collect_params(), "adadelta", {"learning_rate": 0.5, - "multi_precision": True}) + trainer = Trainer(net.collect_params(), "adadelta", {"learning_rate": 0.5}) - for i, (data, label) in enumerate(dataloader): + for i, (data, label) in enumerate(train_dataloader): record_index, q_words, ctx_words, q_chars, ctx_chars = data - q_words = q_words.astype("float16") - ctx_words = ctx_words.astype("float16") - q_chars = q_chars.astype("float16") - ctx_chars = ctx_chars.astype("float16") - label = label.astype("float16") record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) @@ -275,21 +252,13 @@ def test_bidaf_model(): losses = [] - for ri, qw, cw, qc, cc, l, ctx_embedding_begin_state, \ - q_embedding_begin_state, m_layer_begin_state, \ - o_layer_begin_state in zip(record_index, q_words, ctx_words, - q_chars, ctx_chars, label, - ctx_embedding_begin_state_list, - q_embedding_begin_state_list, - m_layer_begin_state_list, - o_layer_begin_state_list): + for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): with autograd.record(): - begin, end = net(qw, cw, qc, cc, - ctx_embedding_begin_state, - q_embedding_begin_state, - m_layer_begin_state, - o_layer_begin_state) - loss = loss_function(begin, end, l) + begin, end = net(qw, cw, qc, cc) + begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) + loss = loss_function(begin, begin_end[0]) + \ + loss_function(end, begin_end[1]) losses.append(loss) for loss in losses: @@ -298,40 +267,6 @@ def test_bidaf_model(): trainer.step(options.batch_size) break - nd.waitall() - - -def test_performance_evaluation(): - options = get_args(batch_size) - - train_dataset = SQuAD(segment='train') - vocab_provider = VocabProvider(train_dataset, options) - - dataset = SQuAD(segment='dev') - mapper = QuestionIdMapper(dataset) - - transformer = SQuADTransform(vocab_provider, question_max_length, - context_max_length, max_chars_per_word, embedding_size) - - # for performance reason, process only batch_size # of records - transformed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) - if i < options.batch_size]) - - word_vocab = vocab_provider.get_word_level_vocab(embedding_size) - word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) - char_vocab = vocab_provider.get_char_level_vocab() - model_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(int(options.epochs) - 1)) - - ctx = [mx.cpu()] - evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, - dataset._read_data(), mapper) - net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") - net.hybridize(static_alloc=True) - net.load_parameters(model_path, ctx=ctx) - - result = evaluator.evaluate_performance(net, ctx, options) - print("Evaluation results on dev dataset: {}".format(result)) - def test_get_answer_spans_exact_match(): tokenizer = BiDAFTokenizer() @@ -360,14 +295,14 @@ def test_get_answer_spans_partial_match(): result = SQuADTransform._get_answer_spans(context, context_tokens, [answer], [answer_start_index]) - assert result == [(16, 17)] + assert result == [(15, 16)] def test_get_answer_spans_unicode(): tokenizer = BiDAFTokenizer() context = "Back in Warsaw that year, Chopin heard Niccolò Paganini play" - context_tokens = tokenizer(context) + context_tokens = tokenizer(context, lower_case=True) answer_start_index = 39 answer = "Niccolò Paganini" @@ -382,7 +317,7 @@ def test_get_answer_spans_after_comma(): tokenizer = BiDAFTokenizer() context = "Chopin's successes as a composer and performer opened the door to western Europe for him, and on 2 November 1830, he set out," - context_tokens = tokenizer(context) + context_tokens = tokenizer(context, lower_case=True) answer_start_index = 108 answer = "1830" @@ -390,7 +325,7 @@ def test_get_answer_spans_after_comma(): result = SQuADTransform._get_answer_spans(context, context_tokens, [answer], [answer_start_index]) - assert result == [(23, 23)] + assert result == [(22, 22)] def test_get_answer_spans_after_quotes(): @@ -412,9 +347,9 @@ def test_get_answer_spans_after_quotes(): def test_get_char_indices(): context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." tokenizer = BiDAFTokenizer() - context_tokens = tokenizer(context) + context_tokens = tokenizer(context, lower_case=True) - result = SQuADTransform._get_char_indices(context, context_tokens) + result = SQuADTransform.get_char_indices(context, context_tokens) assert len(result) == len(context_tokens) @@ -422,7 +357,7 @@ def test_tokenizer_split_new_lines(): context = "that are of equal energy\u2014i.e., degenerate\u2014is a configuration termed a spin triplet state. Hence, the ground state of the O\n2 molecule is referred to as triplet oxygen" tokenizer = BiDAFTokenizer() - context_tokens = tokenizer(context) + context_tokens = tokenizer(context, lower_case=True) assert len(context_tokens) == 35 @@ -431,7 +366,7 @@ def test_polyak_averaging(): net = nn.HybridSequential() net.add(nn.Dense(5), nn.Dense(3), nn.Dense(2)) net.initialize(init.Xavier()) - # net.hybridize() + net.hybridize() ema = None loss_fn = SoftmaxCrossEntropyLoss() @@ -452,6 +387,8 @@ def test_polyak_averaging(): trainer.step(5) ema.update() + assert ema.get_params() is not None + def get_args(batch_size): options = SimpleNamespace() options.gpu = None @@ -466,9 +403,11 @@ def get_args(batch_size): options.ctx_max_len = context_max_length options.q_max_len = question_max_length options.word_max_len = max_chars_per_word - options.precision = "float32" options.epochs = 12 options.save_dir = "output/" options.filter_long_context = False + options.word_vocab_path = "" + options.char_vocab_path = "" + options.train_unk_token = False return options From 1c15ccaf10d1ff44c0952c862d64396fdafa245d Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 15 Nov 2018 14:44:16 -0800 Subject: [PATCH 39/43] Code review changes --- scripts/question_answering/bidaf.py | 1 - scripts/question_answering/data_processing.py | 3 +- .../exponential_moving_average.py | 74 ------------------- .../performance_evaluator.py | 6 +- .../question_answering/question_answering.py | 42 +++++------ .../train_question_answering.py | 20 ++--- scripts/question_answering/utils.py | 56 +++++++++++++- scripts/tests/test_question_answering.py | 6 +- 8 files changed, 89 insertions(+), 119 deletions(-) delete mode 100644 scripts/question_answering/exponential_moving_average.py diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index b06d7ac4eb..a95cf390ed 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -20,7 +20,6 @@ from mxnet import gluon import numpy as np -from .attention_flow import AttentionFlow from .utils import last_dim_softmax, weighted_sum, replace_masked_values, masked_softmax diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 0b94cc8872..2a547db6da 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,13 +19,12 @@ # pylint: disable= """SQuAD data preprocessing.""" -import logging import pickle from os.path import isfile import gluonnlp as nlp -from scripts.question_answering.tokenizer import BiDAFTokenizer +from .tokenizer import BiDAFTokenizer __all__ = ['SQuADTransform', 'VocabProvider'] diff --git a/scripts/question_answering/exponential_moving_average.py b/scripts/question_answering/exponential_moving_average.py deleted file mode 100644 index 651ddc5d73..0000000000 --- a/scripts/question_answering/exponential_moving_average.py +++ /dev/null @@ -1,74 +0,0 @@ -# coding: utf-8 - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable= -"""Exponential Moving Average""" -import mxnet as mx -from mxnet import gluon - - -class PolyakAveraging: - """Class to do Polyak averaging based on this paper - http://www.meyn.ece.ufl.edu/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf""" - def __init__(self, params, decay): - self._params = params - self._decay = decay - - self._polyak_params_dict = gluon.ParameterDict() - - for param in self._params.values(): - polyak_param = self._polyak_params_dict.get(param.name, shape=param.shape) - polyak_param.initialize(mx.init.Constant(self._param_data_to_cpu(param)), ctx=mx.cpu()) - - def update(self): - """ - Updates currently held saved parameters with current state of network. - - All calculations for this average occur on the cpu context. - """ - for param in self._params.values(): - polyak_param = self._polyak_params_dict.get(param.name) - polyak_param.set_data( - (1 - self._decay) * self._param_data_to_cpu(param) + - self._decay * polyak_param.data(mx.cpu())) - - def get_params(self): - """ Provides averaged parameters - - Returns - ------- - gluon.ParameterDict - Averaged parameters - """ - return self._polyak_params_dict - - def _param_data_to_cpu(self, param): - """Returns a copy (on CPU context) of the data held in some context of given parameter. - - Parameters - ---------- - param: gluon.Parameter - Parameter's whose data needs to be copied. - - Returns - ------- - NDArray - Copy of data on CPU context. - """ - return param.list_data()[0].copyto(mx.cpu()) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index afe4840f8d..7e86f44b8e 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -22,9 +22,9 @@ from mxnet import nd, gluon, cpu from mxnet.gluon.data import DataLoader, ArrayDataset -from scripts.question_answering.data_processing import SQuADTransform -from scripts.question_answering.official_squad_eval_script import evaluate -from scripts.question_answering.utils import extend_to_batch_size +from .data_processing import SQuADTransform +from .official_squad_eval_script import evaluate +from .utils import extend_to_batch_size class PerformanceEvaluator: diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 5c499abe89..c3b37eb089 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -18,11 +18,10 @@ # under the License. """BiDAF model blocks""" -from scripts.question_answering.attention_flow import AttentionFlow -from scripts.question_answering.bidaf import BidirectionalAttentionFlow -from scripts.question_answering.similarity_function import DotProductSimilarity, CosineSimilarity, \ - LinearSimilarity -from scripts.question_answering.utils import get_very_negative_number +from .attention_flow import AttentionFlow +from .bidaf import BidirectionalAttentionFlow +from .similarity_function import LinearSimilarity +from .utils import get_very_negative_number __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] @@ -97,7 +96,6 @@ def hybrid_forward(self, F, w, c, *args): def convolute(token_of_all_batches, _): return self._char_conv_embedding(token_of_all_batches), [] - char_embedded, _ = F.contrib.foreach(convolute, char_level_data, []) # Transpose to TNC, to join with character embedding @@ -140,14 +138,14 @@ class BiDAFModelingLayer(HybridBlock): Shared Parameters for this `Block`. """ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, - dropout=0.2, prefix=None, params=None): + dropout=0.2, input_size=800, prefix=None, params=None): super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) self._batch_size = batch_size with self.name_scope(): self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, - bidirectional=biflag, input_size=800) + bidirectional=biflag, input_size=input_size) def hybrid_forward(self, F, x, *args): out = self._modeling_layer(x) @@ -192,30 +190,30 @@ def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, with self.name_scope(): self._dropout = nn.Dropout(rate=dropout) - self._start_index_g = nn.Dense(units=1, in_units=8 * span_start_input_dim, - flatten=False) - self._start_index_m = nn.Dense(units=1, in_units=2 * span_start_input_dim, - flatten=False) + self._start_index_combined = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._start_index_model = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=2 * span_start_input_dim) - self._end_index_g = nn.Dense(units=1, in_units=8 * span_start_input_dim, - flatten=False) - self._end_index_m = nn.Dense(units=1, in_units=2 * span_start_input_dim, - flatten=False) + self._end_index_combined = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._end_index_model = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) def hybrid_forward(self, F, x, m, mask, *args): # pylint: disable=arguments-differ # setting batch size as the first dimension x = F.transpose(x, axes=(1, 0, 2)) - start_index_dense_output = self._start_index_g(self._dropout(x)) + \ - self._start_index_m(self._dropout(F.transpose(m, - axes=(1, 0, 2)))) + start_index_dense_output = self._start_index_combined(self._dropout(x)) + \ + self._start_index_model(self._dropout( + F.transpose(m, axes=(1, 0, 2)))) m2 = self._end_index_lstm(m) - end_index_dense_output = self._end_index_g(self._dropout(x)) + \ - self._end_index_m(self._dropout(F.transpose(m2, - axes=(1, 0, 2)))) + end_index_dense_output = self._end_index_combined(self._dropout(x)) + \ + self._end_index_model(self._dropout(F.transpose(m2, + axes=(1, 0, 2)))) start_index_dense_output = F.squeeze(start_index_dense_output) start_index_dense_output_masked = start_index_dense_output + ((1 - mask) * diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 2251ef4151..c7b24ff760 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -37,13 +37,13 @@ from gluonnlp.data import SQuAD -from scripts.question_answering.data_processing import VocabProvider, SQuADTransform -from scripts.question_answering.exponential_moving_average import PolyakAveraging -from scripts.question_answering.performance_evaluator import PerformanceEvaluator -from scripts.question_answering.question_answering import * -from scripts.question_answering.question_id_mapper import QuestionIdMapper -from scripts.question_answering.tokenizer import BiDAFTokenizer -from scripts.question_answering.utils import logging_config +from .data_processing import VocabProvider, SQuADTransform +from .utils import PolyakAveraging +from .performance_evaluator import PerformanceEvaluator +from .question_answering import * +from .question_id_mapper import QuestionIdMapper +from .tokenizer import BiDAFTokenizer +from .utils import logging_config def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): @@ -288,7 +288,7 @@ def run_training(net, dataloader, ctx, options): if options.grad_req_add_mode == 0 else options.grad_req_add_mode if options.lr_warmup_steps: - trainer.set_learning_rate(get_learning_rate_per_iteration(iteration, options)) + trainer.set_learning_rate(warm_up_steps(iteration, options)) if options.clip or options.train_unk_token: trainer.allreduce_grads() @@ -378,7 +378,7 @@ def run_training(net, dataloader, ctx, options): print("Training time {:6.2f} seconds".format(time() - train_start)) -def get_learning_rate_per_iteration(iteration, options): +def warm_up_steps(iteration, options): """Returns learning rate based on current iteration. Used to implement learning rate warm up technique @@ -601,7 +601,7 @@ def run_training_mode(options): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.initialize(init.Xavier(), ctx=ctx) - net.hybridize(static_alloc=True) +# net.hybridize(static_alloc=True) if options.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index c4413fd44b..7d0fc5a92b 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -16,12 +16,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import argparse import os import inspect import logging -from mxnet import nd +import mxnet as mx +from mxnet import nd, gluon def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=logging.INFO, @@ -74,7 +74,6 @@ def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=lo return folder - def get_very_negative_number(): return -1e30 @@ -325,3 +324,54 @@ def replace_masked_values(F, tensor, mask, replace_with): one_minus_mask = 1.0 - mask values_to_add = replace_with * one_minus_mask return F.broadcast_add(F.broadcast_mul(tensor, mask), values_to_add) + + +class PolyakAveraging: + """Class to do Polyak averaging based on this paper + http://www.meyn.ece.ufl.edu/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf""" + def __init__(self, params, decay): + self._params = params + self._decay = decay + + self._polyak_params_dict = gluon.ParameterDict() + + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name, shape=param.shape) + polyak_param.initialize(mx.init.Constant(self._param_data_to_cpu(param)), ctx=mx.cpu()) + + def update(self): + """ + Updates currently held saved parameters with current state of network. + + All calculations for this average occur on the cpu context. + """ + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name) + polyak_param.set_data( + (1 - self._decay) * self._param_data_to_cpu(param) + + self._decay * polyak_param.data(mx.cpu())) + + def get_params(self): + """ Provides averaged parameters + + Returns + ------- + gluon.ParameterDict + Averaged parameters + """ + return self._polyak_params_dict + + def _param_data_to_cpu(self, param): + """Returns a copy (on CPU context) of the data held in some context of given parameter. + + Parameters + ---------- + param: gluon.Parameter + Parameter's whose data needs to be copied. + + Returns + ------- + NDArray + Copy of data on CPU context. + """ + return param.list_data()[0].copyto(mx.cpu()) diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 0c3382b9ac..c74ce8dab1 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -33,11 +33,9 @@ from scripts.question_answering.attention_flow import AttentionFlow from scripts.question_answering.bidaf import BidirectionalAttentionFlow from scripts.question_answering.data_processing import SQuADTransform, VocabProvider -from scripts.question_answering.exponential_moving_average import PolyakAveraging -from scripts.question_answering.performance_evaluator import PerformanceEvaluator +from scripts.question_answering.utils import PolyakAveraging from scripts.question_answering.question_answering import * -from scripts.question_answering.question_id_mapper import QuestionIdMapper -from scripts.question_answering.similarity_function import DotProductSimilarity, LinearSimilarity +from scripts.question_answering.similarity_function import LinearSimilarity from scripts.question_answering.tokenizer import BiDAFTokenizer from scripts.question_answering.train_question_answering import get_record_per_answer_span From 7303444a1d0cc7cf965d31566e32129468fdd2a5 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 15 Nov 2018 14:50:05 -0800 Subject: [PATCH 40/43] Optimize imports --- scripts/question_answering/tokenizer.py | 2 +- scripts/question_answering/train_question_answering.py | 9 +++------ scripts/question_answering/utils.py | 3 ++- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index 98c18bce57..c6b17a2040 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -16,8 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import re import nltk +import re class BiDAFTokenizer: diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index c7b24ff760..1e7f3e77b1 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -17,23 +17,20 @@ # specific language governing permissions and limitations # under the License. import argparse -import math - import copy +import math import multiprocessing +import logging import os -from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from os.path import isfile - -import logging import pickle - from time import time import mxnet as mx from mxnet import gluon, init, autograd from mxnet.gluon import Trainer from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss from gluonnlp.data import SQuAD diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 7d0fc5a92b..9af5fc5d41 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -16,10 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os import inspect import logging +import os + import mxnet as mx from mxnet import nd, gluon From 1571548b5db9b9d90452833a8dc069efd8fd5f50 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 15 Nov 2018 17:09:05 -0800 Subject: [PATCH 41/43] Only EMA or original params get saved --- .../question_answering/question_answering.py | 12 ++-- .../train_question_answering.py | 62 +++++-------------- 2 files changed, 21 insertions(+), 53 deletions(-) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index c3b37eb089..6b9fd64c57 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -73,18 +73,16 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, bidirectional=True, input_size=2 * embedding_size, dropout=dropout) - def init_embeddings(self, lock_gradients): + def init_embeddings(self, grad_req='null'): """Initialize words embeddings with provided embedding values Parameters ---------- - lock_gradients: bool - Flag to stop parameters from being trained + grad_req: str + How to treat gradients of embedding layer """ self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) - - if lock_gradients: - self._word_embedding.collect_params().setattr('grad_req', 'null') + self._word_embedding.collect_params().setattr('grad_req', grad_req) def hybrid_forward(self, F, w, c, *args): word_embedded = self._word_embedding(w) @@ -283,7 +281,7 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False): super(BiDAFModel, self).initialize(init, ctx, verbose, force_reinit) - self.ctx_embedding.init_embeddings(not self._options.train_unk_token) + self.ctx_embedding.init_embeddings('null' if not self._options.train_unk_token else 'write') def hybrid_forward(self, F, qw, cw, qc, cc, *args): ctx_embedding_output = self.ctx_embedding(cw, cc) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 1e7f3e77b1..e1e009b1b7 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -269,7 +269,7 @@ def run_training(net, dataloader, ctx, options): options.exponential_moving_average_weight_decay) if options.resume_training: - path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format( + path = os.path.join(options.save_dir, 'epoch{:d}.params'.format( options.resume_training - 1)) ema.get_params().load(path) @@ -358,8 +358,7 @@ def run_training(net, dataloader, ctx, options): .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, records_per_epoch_count / epoch_time, epoch_time)) - save_model_parameters(net, e, options) - save_ema_parameters(ema, e, options) + save_model_parameters(net.collect_params() if ema is None else ema.get_params(), e, options) save_trainer_parameters(trainer, e, options) if options.terminate_training_on_reaching_F1_threshold: @@ -458,12 +457,12 @@ def is_fixed_embedding_layer(name): return True if "predefined_embedding_layer" in name else False -def save_model_parameters(net, epoch, options): +def save_model_parameters(params, epoch, options): """Save parameters of the trained model Parameters ---------- - net : `Block` + params : `gluon.ParameterDict` Model with trained parameters epoch : `int` Number of epoch @@ -474,29 +473,7 @@ def save_model_parameters(net, epoch, options): os.mkdir(options.save_dir) save_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(epoch)) - net.save_parameters(save_path) - - -def save_ema_parameters(ema, epoch, options): - """Save exponentially averaged parameters of the trained model - - Parameters - ---------- - ema : `PolyakAveraging` - Model with trained parameters - epoch : `int` - Number of epoch - options : `Namespace` - Saving arguments - """ - if ema is None: - return - - if not os.path.exists(options.save_dir): - os.mkdir(options.save_dir) - - save_path = os.path.join(options.save_dir, 'ema_epoch{:d}.params'.format(epoch)) - ema.get_params().save(save_path) + params.save(save_path) def save_trainer_parameters(trainer, epoch, options): @@ -617,6 +594,10 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): Parameters ---------- + existing_net : `Block` + Trained existing network + existing_net : `PolyakAveraging` + Averaged parameters of the network options : `Namespace` Model evaluation arguments @@ -645,27 +626,16 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") - if options.use_exponential_moving_average: - if existing_ema is None: - params_path = os.path.join(options.save_dir, - 'ema_epoch{:d}.params'.format(int(options.epochs) - 1)) - else: - save_ema_parameters(existing_ema, options.epochs, options) - params_path = os.path.join(options.save_dir, - 'ema_epoch{:d}.params'.format(options.epochs)) + if existing_ema is not None: + save_model_parameters(existing_ema.get_params(), options.epochs, options) - net.collect_params().load(params_path, ctx=ctx) - else: - if existing_net is None: - params_path = os.path.join(options.save_dir, - 'epoch{:d}.params'.format(int(options.epochs) - 1)) - else: - save_model_parameters(existing_net, options.epochs, options) - params_path = os.path.join(options.save_dir, - 'epoch{:d}.params'.format(options.epochs)) + elif existing_net is not None: + save_model_parameters(existing_net.collect_params(), options.epochs, options) - net.load_parameters(params_path, ctx=ctx) + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(int(options.epochs) - 1)) + net.collect_params().load(params_path, ctx=ctx) net.hybridize(static_alloc=True) return evaluator.evaluate_performance(net, ctx, options) From cbbfcb3b30085d0dde63a38761f0e30a04c067ea Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Thu, 15 Nov 2018 17:09:57 -0800 Subject: [PATCH 42/43] Get hybridization back --- scripts/question_answering/train_question_answering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index e1e009b1b7..4935c54893 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -575,7 +575,7 @@ def run_training_mode(options): net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") net.initialize(init.Xavier(), ctx=ctx) -# net.hybridize(static_alloc=True) + net.hybridize(static_alloc=True) if options.grad_req_add_mode: net.collect_params().setattr('grad_req', 'add') From 355116c4b6577264a45742eb079c53f7d79b37e0 Mon Sep 17 00:00:00 2001 From: Sergey Sokolov Date: Fri, 16 Nov 2018 10:57:00 -0800 Subject: [PATCH 43/43] Make pylint happy --- gluonnlp/data/question_answering.py | 9 - scripts/question_answering/attention_flow.py | 1 + scripts/question_answering/bidaf.py | 21 +- scripts/question_answering/data_processing.py | 27 ++- .../official_squad_eval_script.py | 36 +++- .../performance_evaluator.py | 10 +- .../question_answering/question_answering.py | 28 +-- .../question_answering/similarity_function.py | 13 +- scripts/question_answering/tokenizer.py | 13 +- .../train_question_answering.py | 185 ++++++++++-------- scripts/question_answering/utils.py | 13 +- 11 files changed, 210 insertions(+), 146 deletions(-) diff --git a/gluonnlp/data/question_answering.py b/gluonnlp/data/question_answering.py index dafc3fc22c..5016977a16 100644 --- a/gluonnlp/data/question_answering.py +++ b/gluonnlp/data/question_answering.py @@ -108,15 +108,6 @@ def _read_data(self): Question id and list_of_answers also substituted with indices, so it could be later converted into nd.array - Returns - ------- - List[Tuple] - Flatten list of questions - """ - """Read data.json from disk and flats it to the following format: - Entry = (record_index, question_id, question, context, answer_list, answer_start_indices). - Question id and list_of_answers also substituted with indices, so it could be later - converted into nd.array Returns ------- List[Tuple] diff --git a/scripts/question_answering/attention_flow.py b/scripts/question_answering/attention_flow.py index 4d55a5655c..fd7c256df1 100644 --- a/scripts/question_answering/attention_flow.py +++ b/scripts/question_answering/attention_flow.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +"""Attention Flow Layer""" from mxnet import gluon from .similarity_function import DotProductSimilarity diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py index a95cf390ed..fb14a21a9b 100644 --- a/scripts/question_answering/bidaf.py +++ b/scripts/question_answering/bidaf.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +"""Bidirectional attention flow layer""" from mxnet import gluon import numpy as np @@ -28,8 +29,8 @@ class BidirectionalAttentionFlow(gluon.HybridBlock): This class implements Minjoon Seo's `Bidirectional Attention Flow model `_ for answering reading comprehension questions (ICLR 2017). - """ + def __init__(self, batch_size, passage_length, @@ -44,16 +45,26 @@ def __init__(self, self._encoding_dim = encoding_dim def _get_big_negative_value(self): - return np.finfo(np.float32).min + """Provides maximum negative Float32 value + Returns + ------- + value : float32 + Maximum negative float32 value + """ + return np.finfo(np.float32).min def _get_small_positive_value(self): - return np.finfo(np.float32).eps + """Provides minimal possible Float32 value + Returns + ------- + value : float32 + Minimal float32 value + """ + return np.finfo(np.float32).eps def hybrid_forward(self, F, passage_question_similarity, encoded_passage, encoded_question, question_mask, passage_mask): # pylint: disable=arguments-differ - """ - """ # Shape: (batch_size, passage_length, question_length) passage_question_similarity_shape = (self._batch_size, self._passage_length, self._question_length) diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index 2a547db6da..97d1d48a23 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,27 +19,25 @@ # pylint: disable= """SQuAD data preprocessing.""" -import pickle - -from os.path import isfile - -import gluonnlp as nlp -from .tokenizer import BiDAFTokenizer __all__ = ['SQuADTransform', 'VocabProvider'] +import pickle +from os.path import isfile import numpy as np - from mxnet import nd +import gluonnlp as nlp from gluonnlp import Vocab, data from gluonnlp.data.batchify import Pad +from .tokenizer import BiDAFTokenizer class SQuADTransform(object): """SQuADTransform class responsible for converting text data into NDArrays that can be later feed into DataProvider """ + def __init__(self, vocab_provider, question_max_length, context_max_length, max_chars_per_word, embedding_size): self._word_vocab = vocab_provider.get_word_level_vocab(embedding_size) @@ -149,14 +147,14 @@ def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): answer_token_indices.append(context_token_index) if len(answer_token_indices) == 0: - print("Warning: Answer {} not found for context {}".format(answer, context)) + print('Warning: Answer {} not found for context {}'.format(answer, context)) else: answer_span = (answer_token_indices[0], answer_token_indices[len(answer_token_indices) - 1]) answer_spans.append(answer_span) if len(answer_spans) == 0: - print("Warning: No answers found for context {}".format(context_tokens)) + print('Warning: No answers found for context {}'.format(context_tokens)) return answer_spans @@ -248,6 +246,7 @@ def _pad_to_max_word_length(item, max_length): class VocabProvider(object): """Provides word level and character level vocabularies """ + def __init__(self, datasets, options, tokenizer=BiDAFTokenizer()): self._datasets = datasets self._options = options @@ -273,7 +272,7 @@ def get_char_level_vocab(self): Character level vocabulary """ if self._options.char_vocab_path and isfile(self._options.char_vocab_path): - return pickle.load(open(self._options.char_vocab_path, "rb")) + return pickle.load(open(self._options.char_vocab_path, 'rb')) all_chars = [] for dataset in self._datasets: @@ -282,7 +281,7 @@ def get_char_level_vocab(self): char_level_vocab = VocabProvider._create_squad_vocab(all_chars) if self._options.char_vocab_path: - pickle.dump(char_level_vocab, open(self._options.char_vocab_path, "wb")) + pickle.dump(char_level_vocab, open(self._options.char_vocab_path, 'wb')) return char_level_vocab @@ -296,7 +295,7 @@ def get_word_level_vocab(self, embedding_size): """ if self._options.word_vocab_path and isfile(self._options.word_vocab_path): - return pickle.load(open(self._options.word_vocab_path, "rb")) + return pickle.load(open(self._options.word_vocab_path, 'rb')) all_words = [] for dataset in self._datasets: @@ -312,10 +311,10 @@ def get_word_level_vocab(self, embedding_size): if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): count += 1 - print("{}/{} words have embeddings".format(count, len(word_level_vocab))) + print('{}/{} words have embeddings'.format(count, len(word_level_vocab))) if self._options.word_vocab_path: - pickle.dump(word_level_vocab, open(self._options.word_vocab_path, "wb")) + pickle.dump(word_level_vocab, open(self._options.word_vocab_path, 'wb')) return word_level_vocab diff --git a/scripts/question_answering/official_squad_eval_script.py b/scripts/question_answering/official_squad_eval_script.py index e3d01c4d70..b2a8948f1e 100644 --- a/scripts/question_answering/official_squad_eval_script.py +++ b/scripts/question_answering/official_squad_eval_script.py @@ -28,6 +28,21 @@ def f1_score(prediction, ground_truth): + """Calculate F1 score + + Parameters + ---------- + prediction : str + Prediction + + ground_truth : str + Ground truth + + Returns + ------- + F1 : float + F1 score + """ prediction_tokens = normalize_answer(prediction).split() ground_truth_tokens = normalize_answer(ground_truth).split() common = Counter(prediction_tokens) & Counter(ground_truth_tokens) @@ -71,6 +86,21 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): def evaluate(dataset, predictions): + """Evaluate dataset against predictions + + Parameters + ---------- + dataset : `dict` + Dictionary containing JSON of SQuAD 1.1 dataset + + predictions : `dict` + Map of Question id and prediction + + Returns + ------- + result : `dict` + Dictionary with F1 and Exact Match scores + """ f1 = exact_match = total = 0 for article in dataset: for paragraph in article['paragraphs']: @@ -107,7 +137,7 @@ def evaluate(dataset, predictions): print('Evaluation expects v-' + expected_version + ', but got dataset with v-' + dataset_json['version'], file=sys.stderr) - dataset = dataset_json['data'] + dataset_data = dataset_json['data'] with open(args.prediction_file) as prediction_file: - predictions = json.load(prediction_file) - print(json.dumps(evaluate(dataset, predictions))) + predictions_json = json.load(prediction_file) + print(json.dumps(evaluate(dataset_data, predictions_json))) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py index 7e86f44b8e..fdd73d0b6a 100644 --- a/scripts/question_answering/performance_evaluator.py +++ b/scripts/question_answering/performance_evaluator.py @@ -56,7 +56,7 @@ def evaluate_performance(self, net, ctx, options): pred = {} # Allows to ensure that start index is always <= than end index - for c in ctx: + for _ in ctx: answer_mask_matrix = nd.zeros(shape=(1, options.ctx_max_len, options.ctx_max_len), ctx=cpu(0)) for idx in range(options.answer_max_len): @@ -64,14 +64,14 @@ def evaluate_performance(self, net, ctx, options): k=idx, ctx=cpu(0)) eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) - for r in self._evaluation_dataset]) + for r in self._evaluation_dataset]) eval_dataloader = DataLoader(eval_dataset, batch_size=len(ctx) * options.batch_size, last_batch='keep', pin_memory=True, num_workers=(multiprocessing.cpu_count() - len(ctx) - 2)) - for i, data in enumerate(eval_dataloader): + for data in eval_dataloader: record_index, q_words, ctx_words, q_chars, ctx_chars = data record_index = extend_to_batch_size(options.batch_size * len(ctx), record_index, -1) @@ -112,9 +112,9 @@ def evaluate_performance(self, net, ctx, options): pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) if options.save_prediction_path: - with open(options.save_prediction_path, "w") as f: + with open(options.save_prediction_path, 'w') as f: for item in pred.items(): - f.write("{}: {}-{} Answer: {}\n".format(item[0], item[1][0], + f.write('{}: {}-{} Answer: {}\n'.format(item[0], item[1][0], item[1][1], item[1][2])) return evaluate(self._json_data['data'], {k: v[2] for k, v in pred.items()}) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py index 6b9fd64c57..d5ca9d92fd 100644 --- a/scripts/question_answering/question_answering.py +++ b/scripts/question_answering/question_answering.py @@ -18,10 +18,6 @@ # under the License. """BiDAF model blocks""" -from .attention_flow import AttentionFlow -from .bidaf import BidirectionalAttentionFlow -from .similarity_function import LinearSimilarity -from .utils import get_very_negative_number __all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] @@ -32,6 +28,11 @@ from gluonnlp.model import ConvolutionalEncoder, Highway +from .attention_flow import AttentionFlow +from .bidaf import BidirectionalAttentionFlow +from .similarity_function import LinearSimilarity +from .utils import get_very_negative_number + class BiDAFEmbedding(HybridBlock): """BiDAFEmbedding is a class describing embeddings that are separately applied to question @@ -63,7 +64,7 @@ def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, output_size=None ) - self._word_embedding = nn.Embedding(prefix="predefined_embedding_layer", + self._word_embedding = nn.Embedding(prefix='predefined_embedding_layer', input_dim=len(word_vocab), output_dim=embedding_size) @@ -84,7 +85,7 @@ def init_embeddings(self, grad_req='null'): self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) self._word_embedding.collect_params().setattr('grad_req', grad_req) - def hybrid_forward(self, F, w, c, *args): + def hybrid_forward(self, F, w, c, *args): # pylint: disable=arguments-differ word_embedded = self._word_embedding(w) char_level_data = self._char_dense_embedding(c) char_level_data = self._dropout(char_level_data) @@ -145,7 +146,7 @@ def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, bidirectional=biflag, input_size=input_size) - def hybrid_forward(self, F, x, *args): + def hybrid_forward(self, F, x, *args): # pylint: disable=arguments-differ out = self._modeling_layer(x) return out @@ -252,11 +253,11 @@ def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): options.highway_num_layers, options.embedding_size, dropout=options.dropout, - prefix="context_embedding") + prefix='context_embedding') self.similarity_function = LinearSimilarity(array_1_dim=6 * options.embedding_size, array_2_dim=1, - combination="x,y,x*y") + combination='x,y,x*y') self.matrix_attention = AttentionFlow(self.similarity_function, options.batch_size, @@ -283,7 +284,7 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, super(BiDAFModel, self).initialize(init, ctx, verbose, force_reinit) self.ctx_embedding.init_embeddings('null' if not self._options.train_unk_token else 'write') - def hybrid_forward(self, F, qw, cw, qc, cc, *args): + def hybrid_forward(self, F, qw, cw, qc, cc, *args): # pylint: disable=arguments-differ ctx_embedding_output = self.ctx_embedding(cw, cc) q_embedding_output = self.ctx_embedding(qw, qc) @@ -299,15 +300,16 @@ def hybrid_forward(self, F, qw, cw, qc, cc, *args): q_embedding_output) passage_question_similarity = passage_question_similarity.reshape( - shape=(self._options.batch_size, - self._options.ctx_max_len, - self._options.q_max_len)) + shape=(self._options.batch_size, + self._options.ctx_max_len, + self._options.q_max_len)) attention_layer_output = self.attention_layer(passage_question_similarity, ctx_embedding_output, q_embedding_output, q_mask, ctx_mask) + attention_layer_output = F.transpose(attention_layer_output, axes=(1, 0, 2)) # modeling layer expects seq_length x batch_size x channels diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py index 181edf0694..5a6fc7f324 100644 --- a/scripts/question_answering/similarity_function.py +++ b/scripts/question_answering/similarity_function.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +"""Collection of general purpose similarity functions""" + import mxnet as mx from mxnet import gluon, initializer from mxnet.gluon import nn, Parameter @@ -37,7 +39,7 @@ class SimilarityFunction(gluon.HybridBlock): """ default_implementation = 'dot_product' - def hybrid_forward(self, F, array_1, array_2): + def hybrid_forward(self, F, array_1, array_2): # pylint: disable=arguments-differ # pylint: disable=arguments-differ """ Takes two tensors of the same shape, such as ``(batch_size, length_1, length_2, @@ -109,9 +111,9 @@ def __init__(self, activation='linear', **kwargs): super(BilinearSimilarity, self).__init__(**kwargs) - self._weight_matrix = Parameter(name="weight_matrix", + self._weight_matrix = Parameter(name='weight_matrix', shape=(array_1_dim, array_2_dim), init=mx.init.Xavier()) - self._bias = Parameter(name="bias", shape=(array_1_dim,), init=mx.init.Zero()) + self._bias = Parameter(name='bias', shape=(array_1_dim,), init=mx.init.Zero()) if activation == 'linear': self._activation = None @@ -176,15 +178,16 @@ def __init__(self, self._activation = nn.Activation(activation) with self.name_scope(): - self.weight_matrix = self.params.get("weight_matrix", + self.weight_matrix = self.params.get('weight_matrix', shape=(array_2_dim, array_1_dim), init=initializer.Uniform()) if use_bias: - self.bias = self.params.get("bias", + self.bias = self.params.get('bias', shape=(array_2_dim,), init=initializer.Zero()) def hybrid_forward(self, F, array_1, array_2, weight_matrix, bias=None): + # pylint: disable=arguments-differ combined_tensors = combine_tensors(F, self.combination, [array_1, array_2]) dot_product = F.FullyConnected(combined_tensors, weight_matrix, bias=bias, flatten=False, no_bias=not self.use_bias, num_hidden=self.array_2_dim) diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py index c6b17a2040..72b11d78a6 100644 --- a/scripts/question_answering/tokenizer.py +++ b/scripts/question_answering/tokenizer.py @@ -16,8 +16,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import nltk + +"""Tokenizer for SQuAD dataset""" + import re +import nltk class BiDAFTokenizer: @@ -39,7 +42,7 @@ def __call__(self, sample, lower_case=False): """ sample = sample.replace('\n', ' ').replace(u'\u000A', '').replace(u'\u00A0', '') - tokens = [token.replace("''", '"').replace("``", '"') for token in + tokens = [token.replace('\'\'', '"').replace('``', '"') for token in nltk.word_tokenize(sample)] tokens = BiDAFTokenizer._process_tokens(tokens) tokens = [token for token in tokens if len(token) > 0] @@ -64,10 +67,10 @@ def _process_tokens(temp_tokens): List of updated tokens """ tokens = [] - splitters = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", - "\u2019", "\u201D", "\u2018", "\u00B0") + splitters = ('-', '\u2212', '\u2014', '\u2013', '/', '~', '"', '\'', '\u201C', + '\u2019', '\u201D', '\u2018', '\u00B0') for token in temp_tokens: - tokens.extend(re.split("([{}])".format("".join(splitters)), token)) + tokens.extend(re.split('([{}])'.format(''.join(splitters)), token)) return tokens diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py index 4935c54893..c38805ed71 100644 --- a/scripts/question_answering/train_question_answering.py +++ b/scripts/question_answering/train_question_answering.py @@ -16,6 +16,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +"""Main script to train BiDAF model""" + import argparse import copy import math @@ -37,7 +40,7 @@ from .data_processing import VocabProvider, SQuADTransform from .utils import PolyakAveraging from .performance_evaluator import PerformanceEvaluator -from .question_answering import * +from .question_answering import BiDAFModel from .question_id_mapper import QuestionIdMapper from .tokenizer import BiDAFTokenizer from .utils import logging_config @@ -66,6 +69,7 @@ def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): transformer = SQuADTransform(vocab_provider, options.q_max_len, options.ctx_max_len, options.word_max_len, options.embedding_size) + i = 0 transformed_records = [] long_context = 0 long_question = 0 @@ -87,7 +91,7 @@ def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): transformed_records.append(transformed_record) processed_dataset = SimpleDataset(transformed_records) - print("{}/{} records left. Too long context {}, too long query {}".format( + print('{}/{} records left. Too long context {}, too long query {}'.format( len(processed_dataset), i + 1, long_context, long_question)) return processed_dataset @@ -138,7 +142,7 @@ def get_record_per_answer_span(processed_dataset, options): num_workers=(multiprocessing.cpu_count() - len(get_context(options)) - 2)) - print("Total records for training: {}".format(len(labels))) + print('Total records for training: {}'.format(len(labels))) return loadable_data, dataloader @@ -207,10 +211,10 @@ def run_training(net, dataloader, ctx, options): hyperparameters = {'learning_rate': options.lr} if options.rho: - hyperparameters["rho"] = options.rho + hyperparameters['rho'] = options.rho trainer = Trainer(net.collect_params(), options.optimizer, hyperparameters, - kvstore="device", update_on_kvstore=False) + kvstore='device', update_on_kvstore=False) if options.resume_training: path = os.path.join(options.save_dir, @@ -228,10 +232,11 @@ def run_training(net, dataloader, ctx, options): max_iteration = -1 early_stop_tries = 0 - print("Starting training...") + print('Starting training...') for e in range(0 if not options.resume_training else options.resume_training, options.epochs): + i = 0 avg_loss *= 0 # Zero average loss of each epoch records_per_epoch_count = 0 @@ -252,8 +257,8 @@ def run_training(net, dataloader, ctx, options): losses = [] - for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, - q_chars, ctx_chars, label): + for _, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): with autograd.record(): begin, end = net(qw, cw, qc, cc) begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) @@ -277,68 +282,45 @@ def run_training(net, dataloader, ctx, options): # predefined number of grad_req_add_mode which acts like batch_size counter if options.grad_req_add_mode > 0: if not iteration % options.grad_req_add_mode != 0 and \ - iteration != len(dataloader): + iteration != len(dataloader): iteration += 1 continue - scailing_coeff = len(ctx) * options.batch_size \ - if options.grad_req_add_mode == 0 else options.grad_req_add_mode - if options.lr_warmup_steps: trainer.set_learning_rate(warm_up_steps(iteration, options)) - if options.clip or options.train_unk_token: - trainer.allreduce_grads() - gradients = get_gradients(net, ctx[0], options) - - if options.clip: - gluon.utils.clip_global_norm(gradients, options.clip) - - if options.train_unk_token: - reset_embedding_gradients(net, ctx[0]) - - if len(ctx) > 1: - # in multi gpu mode we propagate new gradients to the rest of gpus - for name, parameter in net.collect_params().items(): - grads = parameter.list_grad() - source = grads[0] - destination = grads[1:] - - for dest in destination: - source.copyto(dest) - - trainer.update(scailing_coeff) - else: - trainer.step(scailing_coeff) + execute_trainer_step(net, trainer, ctx, options) if ema is not None: ema.update() if e == options.epochs - 1 and \ - options.log_interval > 0 and \ - iteration > 0 and iteration % options.log_interval == 0: + options.log_interval > 0 and \ + iteration > 0 and iteration % options.log_interval == 0: evaluate_options = copy.deepcopy(options) evaluate_options.epochs = iteration - result = run_evaluate_mode(evaluate_options, net, ema) - - print("Iteration {} evaluation results on dev dataset: {}".format(iteration, - result)) - if options.early_stop: - if result["f1"] > max_dev_f1: - max_dev_f1 = result["f1"] - max_dev_exact = result["exact_match"] - max_iteration = iteration - early_stop_tries = 0 + eval_result = run_evaluate_mode(evaluate_options, net, ema) + + print('Iteration {} evaluation results on dev dataset: {}'.format(iteration, + eval_result)) + if not options.early_stop: + continue + + if eval_result['f1'] > max_dev_f1: + max_dev_f1 = eval_result['f1'] + max_dev_exact = eval_result['exact_match'] + max_iteration = iteration + early_stop_tries = 0 + else: + early_stop_tries += 1 + if early_stop_tries < options.early_stop: + print('Results decreased for {} times'.format(early_stop_tries)) else: - if early_stop_tries < options.early_stop: - early_stop_tries += 1 - print("Results decreased for {} times".format(early_stop_tries)) - else: - print("Results decreased for {} times. Stop training. " - "Best results are stored at {} params file. F1={}, EM={}"\ - .format(options.early_stop + 1, max_iteration, - max_dev_f1, max_dev_exact)) - break + print('Results decreased for {} times. Stop training. ' + 'Best results are stored at {} params file. F1={}, EM={}' \ + .format(options.early_stop + 1, max_iteration, + max_dev_f1, max_dev_exact)) + break for l in losses: avg_loss += l.mean().as_in_context(avg_loss.context) @@ -353,8 +335,8 @@ def run_training(net, dataloader, ctx, options): avg_loss_scalar = avg_loss.asscalar() epoch_time = time() - e_start - print("\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} " - "| throughtput {:5.3f} of samples/sec | Time per epoch {:5.2f} seconds" + print('\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} ' + '| throughtput {:5.3f} of samples/sec | Time per epoch {:5.2f} seconds' .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, records_per_epoch_count / epoch_time, epoch_time)) @@ -364,14 +346,14 @@ def run_training(net, dataloader, ctx, options): if options.terminate_training_on_reaching_F1_threshold: evaluate_options = copy.deepcopy(options) evaluate_options.epochs = e - result = run_evaluate_mode(evaluate_options, net, ema) + eval_result = run_evaluate_mode(evaluate_options, net, ema) - if result["f1"] >= options.terminate_training_on_reaching_F1_threshold: - print("Finishing training on {} epoch, because dev F1 score is >= required {}. {}" - .format(e, options.terminate_training_on_reaching_F1_threshold, result)) + if eval_result['f1'] >= options.terminate_training_on_reaching_F1_threshold: + print('Finishing training on {} epoch, because dev F1 score is >= required {}. {}' + .format(e, options.terminate_training_on_reaching_F1_threshold, eval_result)) break - print("Training time {:6.2f} seconds".format(time() - train_start)) + print('Training time {:6.2f} seconds'.format(time() - train_start)) def warm_up_steps(iteration, options): @@ -454,7 +436,49 @@ def is_fixed_embedding_layer(name): name : `str` Layer name to check """ - return True if "predefined_embedding_layer" in name else False + return True if 'predefined_embedding_layer' in name else False + + +def execute_trainer_step(net, trainer, ctx, options): + """Does training step if doesn't need to do gradient clipping or train unknown symbols. + + Parameters + ---------- + net : `Block` + Network to train + trainer : `Trainer` + Trainer + ctx: list + Context list + options: `SimpleNamespace` + Training options + """ + scailing_coeff = len(ctx) * options.batch_size \ + if options.grad_req_add_mode == 0 else options.grad_req_add_mode + + if options.clip or options.train_unk_token: + trainer.allreduce_grads() + gradients = get_gradients(net, ctx[0], options) + + if options.clip: + gluon.utils.clip_global_norm(gradients, options.clip) + + if options.train_unk_token: + reset_embedding_gradients(net, ctx[0]) + + if len(ctx) > 1: + # in multi gpu mode we propagate new gradients to the rest of gpus + for _, parameter in net.collect_params().items(): + grads = parameter.list_grad() + source = grads[0] + destination = grads[1:] + + for dest in destination: + source.copyto(dest) + + trainer.update(scailing_coeff) + else: + trainer.step(scailing_coeff) def save_model_parameters(params, epoch, options): @@ -508,7 +532,7 @@ def save_transformed_dataset(dataset, path): path : `str` Saving path """ - pickle.dump(dataset, open(path, "wb")) + pickle.dump(dataset, open(path, 'wb')) def load_transformed_dataset(path): @@ -524,7 +548,7 @@ def load_transformed_dataset(path): processed_dataset : SimpleDataset Transformed dataset """ - processed_dataset = pickle.load(open(path, "rb")) + processed_dataset = pickle.load(open(path, 'rb')) return processed_dataset @@ -569,11 +593,11 @@ def run_training_mode(options): enable_filtering=True) save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) - train_dataset, train_dataloader = get_record_per_answer_span(transformed_dataset, options) + _, train_dataloader = get_record_per_answer_span(transformed_dataset, options) word_vocab, char_vocab = get_vocabs(vocab_provider, options=options) ctx = get_context(options) - net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") + net = BiDAFModel(word_vocab, char_vocab, options, prefix='bidaf') net.initialize(init.Xavier(), ctx=ctx) net.hybridize(static_alloc=True) @@ -581,7 +605,7 @@ def run_training_mode(options): net.collect_params().setattr('grad_req', 'add') if options.resume_training: - print("Resuming training from {} epoch".format(options.resume_training)) + print('Resuming training from {} epoch'.format(options.resume_training)) params_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(int(options.resume_training) - 1)) net.load_parameters(params_path, ctx) @@ -624,7 +648,7 @@ def run_evaluate_mode(options, existing_net=None, existing_ema=None): evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, dataset._read_data(), mapper) - net = BiDAFModel(word_vocab, char_vocab, options, prefix="bidaf") + net = BiDAFModel(word_vocab, char_vocab, options, prefix='bidaf') if existing_ema is not None: save_model_parameters(existing_ema.get_params(), options.epochs, options) @@ -651,9 +675,9 @@ def get_args(): parser.add_argument('--evaluate', default=False, action='store_true', help='Run evaluation on dev dataset') parser.add_argument('--preprocessed_dataset_path', type=str, - default="preprocessed_dataset.p", help='Path to preprocessed dataset') + default='preprocessed_dataset.p', help='Path to preprocessed dataset') parser.add_argument('--preprocessed_val_dataset_path', type=str, - default="preprocessed_val_dataset.p", + default='preprocessed_val_dataset.p', help='Path to preprocessed validation dataset') parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') parser.add_argument('--embedding_size', type=int, default=100, @@ -717,11 +741,11 @@ def get_args(): help='Enable rolling gradient mode, where batch size is always 1 and ' 'gradients are accumulated using single GPU') - args = parser.parse_args() - return args + options = parser.parse_args() + return options -if __name__ == "__main__": +if __name__ == '__main__': args = get_args() args.batch_size = int(args.batch_size / len(get_context(args))) print(args) @@ -729,18 +753,17 @@ def get_args(): if args.preprocess: if not args.preprocessed_dataset_path: - logging.error("Preprocessed_data_path attribute is not provided") + logging.error('Preprocessed_data_path attribute is not provided') exit(1) - print("Running in preprocessing mode") + print('Running in preprocessing mode') run_preprocess_mode(args) if args.train: - print("Running in training mode") + print('Running in training mode') run_training_mode(args) if args.evaluate: - print("Running in evaluation mode") + print('Running in evaluation mode') result = run_evaluate_mode(args) - print("Evaluation results on dev dataset: {}".format(result)) - + print('Evaluation results on dev dataset: {}'.format(result)) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py index 9af5fc5d41..fc1ec73ad0 100644 --- a/scripts/question_answering/utils.py +++ b/scripts/question_answering/utils.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +"""Set of utility methods for question answering models""" + import inspect import logging import os @@ -95,7 +97,7 @@ def extend_to_batch_size(batch_size, prototype, fill_value=0): if batch_size == prototype.shape[0]: return prototype - new_shape = (batch_size - prototype.shape[0], ) + prototype.shape[1:] + new_shape = (batch_size - prototype.shape[0],) + prototype.shape[1:] dummy_elements = nd.full(val=fill_value, shape=new_shape, dtype=prototype.dtype, ctx=prototype.context) return nd.concat(prototype, dummy_elements, dim=0) @@ -127,8 +129,6 @@ def _get_combination_dim(combination, tensor_dims): return tensor_dims[index] else: first_tensor_dim = _get_combination_dim(combination[0], tensor_dims) - second_tensor_dim = _get_combination_dim(combination[2], tensor_dims) - operation = combination[1] return first_tensor_dim @@ -249,7 +249,7 @@ def _last_dimension_applicator(F, if mask is not None: shape_difference = len(tensor_shape) - len(mask_shape) - for i in range(0, shape_difference): + for _ in range(0, shape_difference): mask = mask.expand_dims(1) mask = mask.broadcast_to(shape=tensor_shape) mask = mask.reshape(shape=(-1, mask_shape[-1])) @@ -270,8 +270,8 @@ def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape, epsilon): def last_dim_log_softmax(F, tensor, mask, tensor_shape, mask_shape): """ Takes a tensor with 3 or more dimensions and does a masked log softmax over the last dimension. - We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) - has shape ``(batch_size, sequence_length)``. + We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask + (if given) has shape ``(batch_size, sequence_length)``. """ return _last_dimension_applicator(F, masked_log_softmax, tensor, mask, tensor_shape, mask_shape) @@ -330,6 +330,7 @@ def replace_masked_values(F, tensor, mask, replace_with): class PolyakAveraging: """Class to do Polyak averaging based on this paper http://www.meyn.ece.ufl.edu/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf""" + def __init__(self, params, decay): self._params = params self._decay = decay