diff --git a/gluonnlp/data/question_answering.py b/gluonnlp/data/question_answering.py index ffc29766f0..5016977a16 100644 --- a/gluonnlp/data/question_answering.py +++ b/gluonnlp/data/question_answering.py @@ -80,7 +80,7 @@ def __init__(self, segment='train', root=os.path.join('~', '.mxnet', 'datasets', self._segment = segment self._get_data() - super(SQuAD, self).__init__(self._read_data()) + super(SQuAD, self).__init__(SQuAD._get_records(self._read_data())) def _get_data(self): """Load data from the file. Does nothing if data was loaded before @@ -116,9 +116,9 @@ def _read_data(self): _, data_file_name, _ = self._data_file[self._segment] with open(os.path.join(self._root, data_file_name)) as f: - samples = json.load(f) + json_data = json.load(f) - return SQuAD._get_records(samples) + return json_data @staticmethod def _get_records(json_dict): diff --git a/scripts/question_answering/attention_flow.py b/scripts/question_answering/attention_flow.py new file mode 100644 index 0000000000..fd7c256df1 --- /dev/null +++ b/scripts/question_answering/attention_flow.py @@ -0,0 +1,70 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Attention Flow Layer""" +from mxnet import gluon + +from .similarity_function import DotProductSimilarity + + +class AttentionFlow(gluon.HybridBlock): + """ + This ``block`` takes two ndarrays as input and returns a ndarray of attentions. + + We compute the similarity between each row in each matrix and return unnormalized similarity + scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the + caller to deal with masking properly when this output is used. + + By default similarity is computed with a dot product, but you can alternatively use a + parameterized similarity function if you wish. + + + Input: + - ndarray_1: ``(batch_size, num_rows_1, embedding_dim)`` + - ndarray_2: ``(batch_size, num_rows_2, embedding_dim)`` + + Output: + - ``(batch_size, num_rows_1, num_rows_2)`` + + Parameters + ---------- + similarity_function: ``SimilarityFunction``, optional (default=``DotProductSimilarity``) + The similarity function to use when computing the attention. + """ + def __init__(self, similarity_function, batch_size, passage_length, + question_length, embedding_size, **kwargs): + super(AttentionFlow, self).__init__(**kwargs) + + self._similarity_function = similarity_function or DotProductSimilarity() + self._batch_size = batch_size + self._passage_length = passage_length + self._question_length = question_length + self._embedding_size = embedding_size + + def hybrid_forward(self, F, matrix_1, matrix_2): + # pylint: disable=arguments-differ + tiled_matrix_1 = matrix_1.expand_dims(2).broadcast_to(shape=(self._batch_size, + self._passage_length, + self._question_length, + self._embedding_size)) + tiled_matrix_2 = matrix_2.expand_dims(1).broadcast_to(shape=(self._batch_size, + self._passage_length, + self._question_length, + self._embedding_size)) + return self._similarity_function(tiled_matrix_1, tiled_matrix_2) diff --git a/scripts/question_answering/bidaf.py b/scripts/question_answering/bidaf.py new file mode 100644 index 0000000000..fb14a21a9b --- /dev/null +++ b/scripts/question_answering/bidaf.py @@ -0,0 +1,121 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Bidirectional attention flow layer""" +from mxnet import gluon +import numpy as np + +from .utils import last_dim_softmax, weighted_sum, replace_masked_values, masked_softmax + + +class BidirectionalAttentionFlow(gluon.HybridBlock): + """ + This class implements Minjoon Seo's `Bidirectional Attention Flow model + `_ + for answering reading comprehension questions (ICLR 2017). + """ + + def __init__(self, + batch_size, + passage_length, + question_length, + encoding_dim, + **kwargs): + super(BidirectionalAttentionFlow, self).__init__(**kwargs) + + self._batch_size = batch_size + self._passage_length = passage_length + self._question_length = question_length + self._encoding_dim = encoding_dim + + def _get_big_negative_value(self): + """Provides maximum negative Float32 value + Returns + ------- + value : float32 + Maximum negative float32 value + """ + return np.finfo(np.float32).min + + def _get_small_positive_value(self): + """Provides minimal possible Float32 value + Returns + ------- + value : float32 + Minimal float32 value + """ + return np.finfo(np.float32).eps + + def hybrid_forward(self, F, passage_question_similarity, + encoded_passage, encoded_question, question_mask, passage_mask): + # pylint: disable=arguments-differ + # Shape: (batch_size, passage_length, question_length) + passage_question_similarity_shape = (self._batch_size, self._passage_length, + self._question_length) + + question_mask_shape = (self._batch_size, self._question_length) + # Shape: (batch_size, passage_length, question_length) + passage_question_attention = last_dim_softmax(F, + passage_question_similarity, + question_mask, + passage_question_similarity_shape, + question_mask_shape, + epsilon=self._get_small_positive_value()) + # Shape: (batch_size, passage_length, encoding_dim) + encoded_question_shape = (self._batch_size, self._question_length, self._encoding_dim) + passage_question_attention_shape = (self._batch_size, self._passage_length, + self._question_length) + passage_question_vectors = weighted_sum(F, encoded_question, passage_question_attention, + encoded_question_shape, + passage_question_attention_shape) + + # We replace masked values with something really negative here, so they don't affect the + # max below. + masked_similarity = passage_question_similarity if question_mask is None else \ + replace_masked_values(F, + passage_question_similarity, + question_mask.expand_dims(1), + replace_with=self._get_big_negative_value()) + + # Shape: (batch_size, passage_length) + question_passage_similarity = masked_similarity.max(axis=-1) + + # Shape: (batch_size, passage_length) + question_passage_attention = masked_softmax(F, question_passage_similarity, passage_mask, + epsilon=self._get_small_positive_value()) + + # Shape: (batch_size, encoding_dim) + encoded_passage_shape = (self._batch_size, self._passage_length, self._encoding_dim) + question_passage_attention_shape = (self._batch_size, self._passage_length) + question_passage_vector = weighted_sum(F, encoded_passage, question_passage_attention, + encoded_passage_shape, + question_passage_attention_shape) + + # Shape: (batch_size, passage_length, encoding_dim) + tiled_question_passage_vector = question_passage_vector.expand_dims(1) + + # Shape: (batch_size, passage_length, encoding_dim * 4) + final_merged_passage = F.concat(encoded_passage, + passage_question_vectors, + encoded_passage * passage_question_vectors, + F.broadcast_mul(encoded_passage, + tiled_question_passage_vector), + dim=-1) + + return final_merged_passage diff --git a/scripts/question_answering/data_processing.py b/scripts/question_answering/data_processing.py index ad5ec20949..97d1d48a23 100644 --- a/scripts/question_answering/data_processing.py +++ b/scripts/question_answering/data_processing.py @@ -19,46 +19,34 @@ # pylint: disable= """SQuAD data preprocessing.""" -__all__ = ['SQuADTransform', 'VocabProvider', 'preprocess_dataset'] -import re -import numpy as np +__all__ = ['SQuADTransform', 'VocabProvider'] +import pickle +from os.path import isfile +import numpy as np from mxnet import nd -from mxnet.gluon.data import SimpleDataset +import gluonnlp as nlp from gluonnlp import Vocab, data from gluonnlp.data.batchify import Pad - - -def preprocess_dataset(dataset, question_max_length, context_max_length): - """Process SQuAD dataset by creating NDArray version of data - - :param Dataset dataset: SQuAD dataset - :param int question_max_length: Maximum length of question (padded or trimmed to that size) - :param int context_max_length: Maximum length of context (padded or trimmed to that size) - - Returns - ------- - SimpleDataset - Dataset of preprocessed records - """ - vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) - processed_dataset = SimpleDataset(dataset.transform(transformer, lazy=False)) - return processed_dataset +from .tokenizer import BiDAFTokenizer class SQuADTransform(object): """SQuADTransform class responsible for converting text data into NDArrays that can be later feed into DataProvider """ - def __init__(self, vocab_provider, question_max_length, context_max_length): - self._word_vocab = vocab_provider.get_word_level_vocab() + + def __init__(self, vocab_provider, question_max_length, context_max_length, + max_chars_per_word, embedding_size): + self._word_vocab = vocab_provider.get_word_level_vocab(embedding_size) self._char_vocab = vocab_provider.get_char_level_vocab() + self._tokenizer = vocab_provider.get_tokenizer() self._question_max_length = question_max_length self._context_max_length = context_max_length + self._max_chars_per_word = max_chars_per_word self._padder = Pad() @@ -67,48 +55,213 @@ def __call__(self, record_index, question_id, question, context, answer_list, """ Method converts text into numeric arrays based on Vocabulary. Answers are not processed, as they are not needed in input + + Parameters + ---------- + + record_index: int + Index of the record + question_id: str + Question Id + question: str + Question + context: str + Context + answer_list: list[str] + List of answers + answer_start_list: list[int] + List of start indices of answers + + Returns + ------- + record: Tuple + A tuple containing record_index [int], question_id [str], question_words_nd [NDArray], + context_words_nd [NDArray], question_chars_nd [NDArray[, context_chars_nd [NDArray], + answer_spans [list[Tuple]] """ - question_words = self._word_vocab[question.split()[:self._question_max_length]] - context_words = self._word_vocab[context.split()[:self._context_max_length]] + question_tokens = self._tokenizer(question, lower_case=True) + context_tokens = self._tokenizer(context, lower_case=True) + + question_words = self._word_vocab[question_tokens[:self._question_max_length]] + context_words = self._word_vocab[context_tokens[:self._context_max_length]] - question_chars = [self._char_vocab[list(iter(word))] - for word in question.split()[:self._question_max_length]] + question_chars = [self._char_vocab[[character.lower() for character in word]] + for word in question_tokens[:self._question_max_length]] - context_chars = [self._char_vocab[list(iter(word))] - for word in context.split()[:self._context_max_length]] + context_chars = [self._char_vocab[[character.lower() for character in word]] + for word in context_tokens[:self._context_max_length]] - question_words_nd = nd.array(question_words, dtype=np.int32) + question_words_nd = self._pad_to_max_word_length(question_words, self._question_max_length) question_chars_nd = self._padder(question_chars) + question_chars_nd = self._pad_to_max_char_length(question_chars_nd, + self._question_max_length) - context_words_nd = nd.array(context_words, dtype=np.int32) + context_words_nd = self._pad_to_max_word_length(context_words, self._context_max_length) context_chars_nd = self._padder(context_chars) + context_chars_nd = self._pad_to_max_char_length(context_chars_nd, self._context_max_length) + + answer_spans = SQuADTransform._get_answer_spans(context, context_tokens, answer_list, + answer_start_list) - answer_spans = SQuADTransform._get_answer_spans(answer_list, answer_start_list) + return (record_index, question_id, question_words_nd, context_words_nd, + question_chars_nd, context_chars_nd, answer_spans) + + @staticmethod + def _get_answer_spans(context, context_tokens, answer_list, answer_start_list): + """Find all answer spans from the context, returning start_index and end_index. + Each index is a index of a token + + Parameters + ---------- + context: str + Context + context_tokens: list[str] + Tokenized context + answer_list: list[str] + List of answers + answer_start_list: list[int] + List of answers start indices - return record_index, question_id, question_words_nd, context_words_nd, \ - question_chars_nd, context_chars_nd, answer_spans + Returns + ------- + answer_spans: List[Tuple] + List of Tuple(answer_start_index answer_end_index) per question + """ + answer_spans = [] + # SQuAD answers doesn't always match to used tokens in the context. Sometimes there is only + # a partial match. We use the same method as used in original implementation: + # 1. Find char index range for all tokens of context + # 2. Foreach answer + # 2.1 Find char index range for the answer (not tokenized) + # 2.2 Find Context token indices which char indices contains answer char indices + # 2.3. Return first and last token indices + context_char_indices = SQuADTransform.get_char_indices(context, context_tokens) + + for answer_start_char_index, answer in zip(answer_start_list, answer_list): + answer_token_indices = [] + answer_end_char_index = answer_start_char_index + len(answer) + + for context_token_index, context_char_span in enumerate(context_char_indices): + if not (answer_end_char_index <= context_char_span[0] or + answer_start_char_index >= context_char_span[1]): + answer_token_indices.append(context_token_index) + + if len(answer_token_indices) == 0: + print('Warning: Answer {} not found for context {}'.format(answer, context)) + else: + answer_span = (answer_token_indices[0], + answer_token_indices[len(answer_token_indices) - 1]) + answer_spans.append(answer_span) + + if len(answer_spans) == 0: + print('Warning: No answers found for context {}'.format(context_tokens)) + + return answer_spans @staticmethod - def _get_answer_spans(answer_list, answer_start_list): - """Find all answer spans from the context, returning start_index and end_index + def get_char_indices(text, text_tokens): + """Match token with character indices - :param list[str] answer_list: List of all answers - :param list[int] answer_start_list: List of all answers' start indices + Parameters + ---------- + text: str + Text + text_tokens: list[str] + Tokens of the text Returns ------- - List[Tuple] - list of Tuple(answer_start_index answer_end_index) per question + char_indices_per_token: List[Tuple] + List of (start_index, end_index) of characters where the position equals to token index """ - return [(answer_start_list[i], answer_start_list[i] + len(answer)) - for i, answer in enumerate(answer_list)] + char_indices_per_token = [] + current_index = 0 + text_lowered = text.lower() + + for token in text_tokens: + current_index = text_lowered.find(token, current_index) + char_indices_per_token.append((current_index, current_index + len(token))) + current_index += len(token) + + return char_indices_per_token + + def _pad_to_max_char_length(self, item, max_item_length): + """Pads all tokens to maximum size + + Parameters + ---------- + item: NDArray + Matrix of indices + max_item_length: int + Maximum length of a token + + Returns + ------- + NDArray + Padded NDArray + """ + # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims + data_expanded = item.reshape(1, 1, item.shape[0], item.shape[1]).astype(np.float32) + data_padded = nd.pad(data_expanded, + mode='constant', + pad_width=[0, 0, 0, 0, 0, max_item_length - item.shape[0], + 0, self._max_chars_per_word - item.shape[1]], + constant_value=0) + + # reshape back to original dimensions with the last dimension of max_item_length + # We also convert to float32 because it will be necessary later for processing + data_reshaped_back = data_padded.reshape(max_item_length, + self._max_chars_per_word).astype(np.float32) + return data_reshaped_back + + @staticmethod + def _pad_to_max_word_length(item, max_length): + """Pads sentences to maximum length + + Parameters + ---------- + item: NDArray + Vector of words + max_length: int + Maximum length of question/context + + Returns + ------- + NDArray + Padded vector of tokens + """ + data_nd = nd.array(item, dtype=np.float32) + # expand dimensions to 4 and turn to float32, because nd.pad can work only with 4 dims + data_expanded = data_nd.reshape(1, 1, 1, data_nd.shape[0]) + data_padded = nd.pad(data_expanded, + mode='constant', + pad_width=[0, 0, 0, 0, 0, 0, 0, max_length - data_nd.shape[0]], + constant_value=0) + # reshape back to original dimensions with the last dimension of max_length + # We also convert to float32 because it will be necessary later for processing + data_reshaped_back = data_padded.reshape(max_length).astype(np.float32) + return data_reshaped_back class VocabProvider(object): """Provides word level and character level vocabularies """ - def __init__(self, dataset): - self._dataset = dataset + + def __init__(self, datasets, options, tokenizer=BiDAFTokenizer()): + self._datasets = datasets + self._options = options + self._tokenizer = tokenizer + + def get_tokenizer(self): + """Provides tokenizer used to create vocab + + Returns + ------- + tokenizer: Tokenizer + Tokenizer + + """ + return self._tokenizer def get_char_level_vocab(self): """Provides character level vocabulary @@ -118,9 +271,21 @@ def get_char_level_vocab(self): Vocab Character level vocabulary """ - return VocabProvider._create_squad_vocab(iter, self._dataset) + if self._options.char_vocab_path and isfile(self._options.char_vocab_path): + return pickle.load(open(self._options.char_vocab_path, 'rb')) - def get_word_level_vocab(self): + all_chars = [] + for dataset in self._datasets: + all_chars.extend(self._get_all_char_tokens(dataset)) + + char_level_vocab = VocabProvider._create_squad_vocab(all_chars) + + if self._options.char_vocab_path: + pickle.dump(char_level_vocab, open(self._options.char_vocab_path, 'wb')) + + return char_level_vocab + + def get_word_level_vocab(self, embedding_size): """Provides word level vocabulary Returns @@ -129,19 +294,92 @@ def get_word_level_vocab(self): Word level vocabulary """ - def simple_tokenize(source_str, token_delim=' ', seq_delim='\n'): - return list(filter(None, re.split(token_delim + '|' + seq_delim, source_str))) + if self._options.word_vocab_path and isfile(self._options.word_vocab_path): + return pickle.load(open(self._options.word_vocab_path, 'rb')) + + all_words = [] + for dataset in self._datasets: + all_words.extend(self._get_all_word_tokens(dataset)) + + word_level_vocab = VocabProvider._create_squad_vocab(all_words) + word_level_vocab.set_embedding( + nlp.embedding.create('glove', source='glove.6B.{}d'.format(embedding_size))) + + count = 0 + + for i in range(len(word_level_vocab)): + if (word_level_vocab.embedding.idx_to_vec[i].sum() != 0).asscalar(): + count += 1 + + print('{}/{} words have embeddings'.format(count, len(word_level_vocab))) + + if self._options.word_vocab_path: + pickle.dump(word_level_vocab, open(self._options.word_vocab_path, 'wb')) - return VocabProvider._create_squad_vocab(simple_tokenize, self._dataset) + return word_level_vocab @staticmethod - def _create_squad_vocab(tokenization_fn, dataset): - all_tokens = [] + def _create_squad_vocab(all_tokens): + """Provides vocabulary based on list of tokens - for data_item in dataset: - all_tokens.extend(tokenization_fn(data_item[1])) - all_tokens.extend(tokenization_fn(data_item[2])) + Parameters + ---------- + + all_tokens: List[str] + List of all tokens + Returns + ------- + Vocab + Vocabulary + """ counter = data.count_tokens(all_tokens) vocab = Vocab(counter) return vocab + + def _get_all_word_tokens(self, dataset): + """Provides all words from context and question of the dataset + + Parameters + ---------- + + dataset: SimpleDataset + Dataset of SQuAD + + Returns + ------- + all_tokens: list[str] + List of all words + """ + all_tokens = [] + + for data_item in dataset: + all_tokens.extend(self._tokenizer(data_item[2], lower_case=True)) + all_tokens.extend(self._tokenizer(data_item[3], lower_case=True)) + + return all_tokens + + def _get_all_char_tokens(self, dataset): + """Provides all characters from context and question of the dataset + + Parameters + ---------- + + dataset: SimpleDataset + Dataset of SQuAD + + Returns + ------- + all_tokens: list[str] + List of all characters + """ + all_tokens = [] + + for data_item in dataset: + for character in data_item[2]: + all_tokens.extend(character.lower()) + + for character in data_item[3]: + all_tokens.extend(character.lower()) + + return all_tokens diff --git a/scripts/question_answering/official_squad_eval_script.py b/scripts/question_answering/official_squad_eval_script.py new file mode 100644 index 0000000000..b2a8948f1e --- /dev/null +++ b/scripts/question_answering/official_squad_eval_script.py @@ -0,0 +1,143 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Official evaluation script for v1.1 of the SQuAD dataset. """ +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys + + +def f1_score(prediction, ground_truth): + """Calculate F1 score + + Parameters + ---------- + prediction : str + Prediction + + ground_truth : str + Ground truth + + Returns + ------- + F1 : float + F1 score + """ + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + """Evaluate dataset against predictions + + Parameters + ---------- + dataset : `dict` + Dictionary containing JSON of SQuAD 1.1 dataset + + predictions : `dict` + Map of Question id and prediction + + Returns + ------- + result : `dict` + Dictionary with F1 and Exact Match scores + """ + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {'exact_match': exact_match, 'f1': f1} + + +if __name__ == '__main__': + expected_version = '1.1' + parser = argparse.ArgumentParser( + description='Evaluation for SQuAD ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (dataset_json['version'] != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset_data = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions_json = json.load(prediction_file) + print(json.dumps(evaluate(dataset_data, predictions_json))) diff --git a/scripts/question_answering/performance_evaluator.py b/scripts/question_answering/performance_evaluator.py new file mode 100644 index 0000000000..fdd73d0b6a --- /dev/null +++ b/scripts/question_answering/performance_evaluator.py @@ -0,0 +1,180 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Performance evaluator - a proxy class used for plugging in official validation script""" +import multiprocessing +from mxnet import nd, gluon, cpu +from mxnet.gluon.data import DataLoader, ArrayDataset + +from .data_processing import SQuADTransform +from .official_squad_eval_script import evaluate +from .utils import extend_to_batch_size + + +class PerformanceEvaluator: + """Plugin to run prediction and performance evaluation via official eval script""" + def __init__(self, tokenizer, evaluation_dataset, json_data, question_id_mapper): + self._tokenizer = tokenizer + self._evaluation_dataset = evaluation_dataset + self._json_data = json_data + self._mapper = question_id_mapper + + def evaluate_performance(self, net, ctx, options): + """Get results of evaluation by official evaluation script + + Parameters + ---------- + net : `Block` + Network + ctx : `Context` + Execution context + options : `Namespace` + Training arguments + + Returns + ------- + data : `dict` + Returns a dictionary of {'exact_match': , 'f1': } + """ + + pred = {} + + # Allows to ensure that start index is always <= than end index + for _ in ctx: + answer_mask_matrix = nd.zeros(shape=(1, options.ctx_max_len, options.ctx_max_len), + ctx=cpu(0)) + for idx in range(options.answer_max_len): + answer_mask_matrix += nd.eye(N=options.ctx_max_len, M=options.ctx_max_len, + k=idx, ctx=cpu(0)) + + eval_dataset = ArrayDataset([(self._mapper.question_id_to_idx[r[1]], r[2], r[3], r[4], r[5]) + for r in self._evaluation_dataset]) + eval_dataloader = DataLoader(eval_dataset, + batch_size=len(ctx) * options.batch_size, + last_batch='keep', + pin_memory=True, + num_workers=(multiprocessing.cpu_count() - len(ctx) - 2)) + + for data in eval_dataloader: + record_index, q_words, ctx_words, q_chars, ctx_chars = data + + record_index = extend_to_batch_size(options.batch_size * len(ctx), record_index, -1) + q_words = extend_to_batch_size(options.batch_size * len(ctx), q_words) + ctx_words = extend_to_batch_size(options.batch_size * len(ctx), ctx_words) + q_chars = extend_to_batch_size(options.batch_size * len(ctx), q_chars) + ctx_chars = extend_to_batch_size(options.batch_size * len(ctx), ctx_chars) + + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + + outs = [] + + for ri, qw, cw, qc, cc in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars): + begin, end = net(qw, cw, qc, cc) + outs.append((ri.as_in_context(cpu(0)), + begin.as_in_context(cpu(0)), + end.as_in_context(cpu(0)))) + + for out in outs: + ri = out[0] + start = out[1].softmax(axis=1) + end = out[2].softmax(axis=1) + start_end_span = PerformanceEvaluator._get_indices(start, end, answer_mask_matrix) + + # iterate over batches + for idx, start_end in zip(ri, start_end_span): + idx = int(idx.asscalar()) + start = int(start_end[0].asscalar()) + end = int(start_end[1].asscalar()) + + if idx in self._mapper.idx_to_question_id: + question_id = self._mapper.idx_to_question_id[idx] + pred[question_id] = (start, end, self.get_text_result(idx, (start, end))) + + if options.save_prediction_path: + with open(options.save_prediction_path, 'w') as f: + for item in pred.items(): + f.write('{}: {}-{} Answer: {}\n'.format(item[0], item[1][0], + item[1][1], item[1][2])) + + return evaluate(self._json_data['data'], {k: v[2] for k, v in pred.items()}) + + def get_text_result(self, idx, answer_span): + """Converts answer span into actual text from paragraph + + Parameters + ---------- + idx : `int` + Question index + answer_span : `Tuple` + Answer span (start_index, end_index) + + Returns + ------- + text : `str` + A chunk of text for provided answer_span or None if answer span cannot be provided + """ + + start, end = answer_span + + if start > end: + return '' + + question_id = self._mapper.idx_to_question_id[idx] + context = self._mapper.question_id_to_context[question_id] + context_tokens = self._tokenizer(context, lower_case=True) + indices = SQuADTransform.get_char_indices(context, context_tokens) + + # get text from cutting string from the initial context + # because tokens are hard to combine together + text = context[indices[start][0]:indices[end][1]] + return text + + @staticmethod + def _get_indices(begin, end, answer_mask_matrix): + r"""Select the begin and end position of answer span. + + At inference time, the predicted span (s, e) is chosen such that + begin_hat[s] * end_hat[e] is maximized and s ≤ e. + + Parameters + ---------- + begin : NDArray + input tensor with shape `(batch_size, context_sequence_length)` + end : NDArray + input tensor with shape `(batch_size, context_sequence_length)` + + Returns + ------- + prediction: Tuple + Tuple containing first and last token indices of the answer + """ + begin_hat = begin.reshape(begin.shape + (1,)) + end_hat = end.reshape(end.shape + (1,)) + end_hat = end_hat.transpose(axes=(0, 2, 1)) + + result = nd.batch_dot(begin_hat, end_hat) * answer_mask_matrix.slice( + begin=(0, 0, 0), end=(1, begin_hat.shape[1], begin_hat.shape[1])) + yp1 = result.max(axis=2).argmax(axis=1, keepdims=True).astype('int32') + yp2 = result.max(axis=1).argmax(axis=1, keepdims=True).astype('int32') + return nd.concat(yp1, yp2, dim=-1) diff --git a/scripts/question_answering/question_answering.py b/scripts/question_answering/question_answering.py new file mode 100644 index 0000000000..d5ca9d92fd --- /dev/null +++ b/scripts/question_answering/question_answering.py @@ -0,0 +1,320 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""BiDAF model blocks""" + +__all__ = ['BiDAFEmbedding', 'BiDAFModelingLayer', 'BiDAFOutputLayer', 'BiDAFModel'] + +from mxnet import initializer +from mxnet.gluon import HybridBlock +from mxnet.gluon import nn +from mxnet.gluon.rnn import LSTM + +from gluonnlp.model import ConvolutionalEncoder, Highway + +from .attention_flow import AttentionFlow +from .bidaf import BidirectionalAttentionFlow +from .similarity_function import LinearSimilarity +from .utils import get_very_negative_number + + +class BiDAFEmbedding(HybridBlock): + """BiDAFEmbedding is a class describing embeddings that are separately applied to question + and context of the datasource. Both question and context are passed in two NDArrays: + 1. Matrix of words: batch_size x words_per_question/context + 2. Tensor of characters: batch_size x words_per_question/context x chars_per_word + """ + def __init__(self, batch_size, word_vocab, char_vocab, max_seq_len, + contextual_embedding_nlayers=2, highway_nlayers=2, embedding_size=100, + dropout=0.2, prefix=None, params=None): + super(BiDAFEmbedding, self).__init__(prefix=prefix, params=params) + + self._word_vocab = word_vocab + self._batch_size = batch_size + self._max_seq_len = max_seq_len + self._embedding_size = embedding_size + + with self.name_scope(): + self._char_dense_embedding = nn.Embedding(input_dim=len(char_vocab), + output_dim=8, + weight_initializer=initializer.Uniform(0.001)) + self._dropout = nn.Dropout(rate=dropout) + self._char_conv_embedding = ConvolutionalEncoder( + embed_size=8, + num_filters=(100,), + ngram_filter_sizes=(5,), + num_highway=None, + conv_layer_activation='relu', + output_size=None + ) + + self._word_embedding = nn.Embedding(prefix='predefined_embedding_layer', + input_dim=len(word_vocab), + output_dim=embedding_size) + + self._highway_network = Highway(2 * embedding_size, num_layers=highway_nlayers) + self._contextual_embedding = LSTM(hidden_size=embedding_size, + num_layers=contextual_embedding_nlayers, + bidirectional=True, input_size=2 * embedding_size, + dropout=dropout) + + def init_embeddings(self, grad_req='null'): + """Initialize words embeddings with provided embedding values + + Parameters + ---------- + grad_req: str + How to treat gradients of embedding layer + """ + self._word_embedding.weight.set_data(self._word_vocab.embedding.idx_to_vec) + self._word_embedding.collect_params().setattr('grad_req', grad_req) + + def hybrid_forward(self, F, w, c, *args): # pylint: disable=arguments-differ + word_embedded = self._word_embedding(w) + char_level_data = self._char_dense_embedding(c) + char_level_data = self._dropout(char_level_data) + + # Transpose to put seq_len first axis to iterate over it + char_level_data = F.transpose(char_level_data, axes=(1, 2, 0, 3)) + + def convolute(token_of_all_batches, _): + return self._char_conv_embedding(token_of_all_batches), [] + char_embedded, _ = F.contrib.foreach(convolute, char_level_data, []) + + # Transpose to TNC, to join with character embedding + word_embedded = F.transpose(word_embedded, axes=(1, 0, 2)) + highway_input = F.concat(char_embedded, word_embedded, dim=2) + + def highway(token_of_all_batches, _): + return self._highway_network(token_of_all_batches), [] + + # Pass through highway, shape remains unchanged + highway_output, _ = F.contrib.foreach(highway, highway_input, []) + + # Transpose to TNC - default for LSTM + ce_output = self._contextual_embedding(highway_output) + return ce_output + + +class BiDAFModelingLayer(HybridBlock): + """BiDAFModelingLayer implements modeling layer of BiDAF paper. It is used to scan over context + produced by Attentional Flow Layer via 2 layer bi-LSTM. + + The input data for the forward should be of dimension 8 * hidden_size (default hidden_size + is 100). + + Parameters + ---------- + + input_dim : `int`, default 100 + The number of features in the hidden state h of LSTM + nlayers : `int`, default 2 + Number of recurrent layers. + biflag: `bool`, default True + If `True`, becomes a bidirectional RNN. + dropout: `float`, default 0 + If non-zero, introduces a dropout layer on the outputs of each + RNN layer except the last layer. + prefix : `str` or None + Prefix of this `Block`. + params : `ParameterDict` or `None` + Shared Parameters for this `Block`. + """ + def __init__(self, batch_size, input_dim=100, nlayers=2, biflag=True, + dropout=0.2, input_size=800, prefix=None, params=None): + super(BiDAFModelingLayer, self).__init__(prefix=prefix, params=params) + + self._batch_size = batch_size + + with self.name_scope(): + self._modeling_layer = LSTM(hidden_size=input_dim, num_layers=nlayers, dropout=dropout, + bidirectional=biflag, input_size=input_size) + + def hybrid_forward(self, F, x, *args): # pylint: disable=arguments-differ + out = self._modeling_layer(x) + return out + + +class BiDAFOutputLayer(HybridBlock): + """ + ``BiDAFOutputLayer`` produces the final prediction of an answer. The output is a tuple of + start and end index of token in the paragraph per each batch. + + It accepts 2 inputs: + `x` : the output of Attention layer of shape: + seq_max_length x batch_size x 8 * span_start_input_dim + + `m` : the output of Modeling layer of shape: + seq_max_length x batch_size x 2 * span_start_input_dim + + Parameters + ---------- + batch_size : `int` + Size of a batch + span_start_input_dim : `int`, default 100 + The number of features in the hidden state h of LSTM + nlayers : `int`, default 1 + Number of recurrent layers. + biflag: `bool`, default True + If `True`, becomes a bidirectional RNN. + dropout: `float`, default 0 + If non-zero, introduces a dropout layer on the outputs of each + RNN layer except the last layer. + prefix : `str` or None + Prefix of this `Block`. + params : `ParameterDict` or `None` + Shared Parameters for this `Block`. + """ + def __init__(self, batch_size, span_start_input_dim=100, nlayers=1, biflag=True, + dropout=0.2, prefix=None, params=None): + super(BiDAFOutputLayer, self).__init__(prefix=prefix, params=params) + + self._batch_size = batch_size + + with self.name_scope(): + self._dropout = nn.Dropout(rate=dropout) + self._start_index_combined = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._start_index_model = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) + self._end_index_lstm = LSTM(hidden_size=span_start_input_dim, + num_layers=nlayers, dropout=dropout, bidirectional=biflag, + input_size=2 * span_start_input_dim) + self._end_index_combined = nn.Dense(units=1, in_units=8 * span_start_input_dim, + flatten=False) + self._end_index_model = nn.Dense(units=1, in_units=2 * span_start_input_dim, + flatten=False) + + def hybrid_forward(self, F, x, m, mask, *args): # pylint: disable=arguments-differ + # setting batch size as the first dimension + x = F.transpose(x, axes=(1, 0, 2)) + + start_index_dense_output = self._start_index_combined(self._dropout(x)) + \ + self._start_index_model(self._dropout( + F.transpose(m, axes=(1, 0, 2)))) + + m2 = self._end_index_lstm(m) + end_index_dense_output = self._end_index_combined(self._dropout(x)) + \ + self._end_index_model(self._dropout(F.transpose(m2, + axes=(1, 0, 2)))) + + start_index_dense_output = F.squeeze(start_index_dense_output) + start_index_dense_output_masked = start_index_dense_output + ((1 - mask) * + get_very_negative_number()) + + end_index_dense_output = F.squeeze(end_index_dense_output) + end_index_dense_output_masked = end_index_dense_output + ((1 - mask) * + get_very_negative_number()) + + return start_index_dense_output_masked, \ + end_index_dense_output_masked + + +class BiDAFModel(HybridBlock): + """Bidirectional attention flow model for Question answering. Implemented according to the + following work: + + @article{DBLP:journals/corr/abs-1804-09541, + author = {Minjoon Seo and + Aniruddha Kembhavi and + Ali Farhadi and + Hannaneh Hajishirzi}, + title = {Bidirectional Attention Flow for Machine Comprehension}, + year = {2016}, + url = {https://arxiv.org/abs/1611.01603} + } + """ + def __init__(self, word_vocab, char_vocab, options, prefix=None, params=None): + super().__init__(prefix=prefix, params=params) + self._options = options + + with self.name_scope(): + self.ctx_embedding = BiDAFEmbedding(options.batch_size, + word_vocab, + char_vocab, + options.ctx_max_len, + options.ctx_embedding_num_layers, + options.highway_num_layers, + options.embedding_size, + dropout=options.dropout, + prefix='context_embedding') + + self.similarity_function = LinearSimilarity(array_1_dim=6 * options.embedding_size, + array_2_dim=1, + combination='x,y,x*y') + + self.matrix_attention = AttentionFlow(self.similarity_function, + options.batch_size, + options.ctx_max_len, + options.q_max_len, + 2 * options.embedding_size) + + # we multiple embedding_size by 2 because we use bidirectional embedding + self.attention_layer = BidirectionalAttentionFlow(options.batch_size, + options.ctx_max_len, + options.q_max_len, + 2 * options.embedding_size) + self.modeling_layer = BiDAFModelingLayer(options.batch_size, + input_dim=options.embedding_size, + nlayers=options.modeling_num_layers, + dropout=options.dropout) + self.output_layer = BiDAFOutputLayer(options.batch_size, + span_start_input_dim=options.embedding_size, + nlayers=options.output_num_layers, + dropout=options.dropout) + + def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, + force_reinit=False): + super(BiDAFModel, self).initialize(init, ctx, verbose, force_reinit) + self.ctx_embedding.init_embeddings('null' if not self._options.train_unk_token else 'write') + + def hybrid_forward(self, F, qw, cw, qc, cc, *args): # pylint: disable=arguments-differ + ctx_embedding_output = self.ctx_embedding(cw, cc) + q_embedding_output = self.ctx_embedding(qw, qc) + + # attention layer expect batch_size x seq_length x channels + ctx_embedding_output = F.transpose(ctx_embedding_output, axes=(1, 0, 2)) + q_embedding_output = F.transpose(q_embedding_output, axes=(1, 0, 2)) + + # Both masks can be None + q_mask = qw != 0 + ctx_mask = cw != 0 + + passage_question_similarity = self.matrix_attention(ctx_embedding_output, + q_embedding_output) + + passage_question_similarity = passage_question_similarity.reshape( + shape=(self._options.batch_size, + self._options.ctx_max_len, + self._options.q_max_len)) + + attention_layer_output = self.attention_layer(passage_question_similarity, + ctx_embedding_output, + q_embedding_output, + q_mask, + ctx_mask) + + attention_layer_output = F.transpose(attention_layer_output, axes=(1, 0, 2)) + + # modeling layer expects seq_length x batch_size x channels + modeling_layer_output = self.modeling_layer(attention_layer_output) + + output = self.output_layer(attention_layer_output, modeling_layer_output, ctx_mask) + + return output diff --git a/scripts/question_answering/question_id_mapper.py b/scripts/question_answering/question_id_mapper.py new file mode 100644 index 0000000000..0b8307d77f --- /dev/null +++ b/scripts/question_answering/question_id_mapper.py @@ -0,0 +1,63 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Question id mapper to and from int""" + + +class QuestionIdMapper: + """Stores mapping between question id and context of SQuAD dataset + + """ + def __init__(self, dataset): + self._question_id_to_context = {item[1]: item[3] for item in dataset} + self._question_id_to_idx = {item[1]: item[0] for item in dataset} + self._idx_to_question_id = {v: k for k, v in self._question_id_to_idx.items()} + + @property + def question_id_to_context(self): + """Provides question Id to context map + + Returns + ------- + map: Dict + Question Id to Context map + """ + return self._question_id_to_context + + @property + def idx_to_question_id(self): + """Provides record index to question Id map + + Returns + ------- + map: Dict + Record index to question Id map + """ + return self._idx_to_question_id + + @property + def question_id_to_idx(self): + """Provides question Id to record index map + + Returns + ------- + map: Dict + Question Id to record index map + """ + return self._question_id_to_idx diff --git a/scripts/question_answering/similarity_function.py b/scripts/question_answering/similarity_function.py new file mode 100644 index 0000000000..5a6fc7f324 --- /dev/null +++ b/scripts/question_answering/similarity_function.py @@ -0,0 +1,198 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Collection of general purpose similarity functions""" + +import mxnet as mx +from mxnet import gluon, initializer +from mxnet.gluon import nn, Parameter + +from .utils import combine_tensors + + +class SimilarityFunction(gluon.HybridBlock): + """ + A ``SimilarityFunction`` takes a pair of tensors with the same shape, and computes a similarity + function on the vectors in the last dimension. For example, the tensors might both have shape + `(batch_size, sentence_length, embedding_dim)`, and we will compute some function of the two + vectors of length `embedding_dim` for each position `(batch_size, sentence_length)`, returning a + tensor of shape `(batch_size, sentence_length)`. + + The similarity function could be as simple as a dot product, or it could be a more complex, + parameterized function. + """ + default_implementation = 'dot_product' + + def hybrid_forward(self, F, array_1, array_2): # pylint: disable=arguments-differ + # pylint: disable=arguments-differ + """ + Takes two tensors of the same shape, such as ``(batch_size, length_1, length_2, + embedding_dim)``. Computes a (possibly parameterized) similarity on the final dimension + and returns a tensor with one less dimension, such as ``(batch_size, length_1, length_2)``. + """ + raise NotImplementedError + + +class DotProductSimilarity(SimilarityFunction): + """ + This similarity function simply computes the dot product between each pair of vectors, with an + optional scaling to reduce the variance of the output elements. + + Parameters + ---------- + scale_output : ``bool``, optional + If ``True``, we will scale the output by ``F.sqrt(ndarray.shape[-1])``, to reduce the + variance in the result. + """ + def __init__(self, scale_output=False, **kwargs): + super(DotProductSimilarity, self).__init__(**kwargs) + self._scale_output = scale_output + + def hybrid_forward(self, F, array_1, array_2): + result = (array_1 * array_2).sum(axis=-1) + + if self._scale_output: + result *= F.sqrt(array_1.shape[-1]) + + return result + + +class CosineSimilarity(SimilarityFunction): + """ + This similarity function simply computes the cosine similarity between each pair of vectors. + It has no parameters. + """ + + def hybrid_forward(self, F, array_1, array_2): + normalized_array_1 = F.broadcast_div(array_1, F.norm(array_1, axis=-1, keepdims=True)) + normalized_array_2 = F.broadcast_div(array_2, F.norm(array_2, axis=-1, keepdims=True)) + return (normalized_array_1 * normalized_array_2).sum(axis=-1) + + +class BilinearSimilarity(SimilarityFunction): + """ + This similarity function performs a bilinear transformation of the two input vectors. This + function has a matrix of weights ``W`` and a bias ``b``, and the similarity between two vectors + ``x`` and ``y`` is computed as ``x^T W y + b``. + + Parameters + ---------- + array_1_dim : ``int`` + The dimension of the first ndarray, ``x``, described above. This is ``x.shape[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build the weight matrix correctly. + array_2_dim : ``int`` + The dimension of the second ndarray, ``y``, described above. This is ``y.shape[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build the weight matrix correctly. + activation : ``Activation``, optional (default=linear (i.e. no activation)) + An activation function applied after the ``x^T W y + b`` calculation. Default is no + activation. + """ + def __init__(self, + array_1_dim, + array_2_dim, + activation='linear', + **kwargs): + super(BilinearSimilarity, self).__init__(**kwargs) + self._weight_matrix = Parameter(name='weight_matrix', + shape=(array_1_dim, array_2_dim), init=mx.init.Xavier()) + self._bias = Parameter(name='bias', shape=(array_1_dim,), init=mx.init.Zero()) + + if activation == 'linear': + self._activation = None + else: + self._activation = nn.Activation(activation) + + def hybrid_forward(self, F, array_1, array_2): + intermediate = F.broadcast_mull(array_1, self._weight_matrix) + result = F.broadcast_mull(intermediate, array_2).sum(axis=-1) + + if not self._activation: + return result + + return self._activation(result + self._bias) + + +class LinearSimilarity(SimilarityFunction): + """ + This similarity function performs a dot product between a vector of weights and some + combination of the two input vectors, followed by an (optional) activation function. The + combination used is configurable. + + If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``, + ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed + elementwise. You can list as many combinations as you want, comma separated. For example, you + might give ``x,y,x*y`` as the ``combination`` parameter to this class. The computed similarity + function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a + bias parameter, and ``[;]`` is vector concatenation. + + Parameters + ---------- + array_1_dim : ``int`` + The dimension of the first tensor, ``x``, described above. This is ``x.size()[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build weight vectors correctly. + array_2_dim : ``int`` + The dimension of the second tensor, ``y``, described above. This is ``y.size()[-1]`` - the + length of the vector that will go into the similarity computation. We need this so we can + build weight vectors correctly. + combination : ``str``, optional (default="x,y") + Described above. + activation : ``Activation``, optional (default=linear (i.e. no activation)) + An activation function applied after the ``w^T * [x;y] + b`` calculation. Default is no + activation. + """ + def __init__(self, + array_1_dim, + array_2_dim, + use_bias=False, + combination='x,y', + activation='linear', + **kwargs): + super(LinearSimilarity, self).__init__(**kwargs) + self.combination = combination + self.use_bias = use_bias + self.array_1_dim = array_1_dim + self.array_2_dim = array_2_dim + + if activation == 'linear': + self._activation = None + else: + self._activation = nn.Activation(activation) + + with self.name_scope(): + self.weight_matrix = self.params.get('weight_matrix', + shape=(array_2_dim, array_1_dim), + init=initializer.Uniform()) + if use_bias: + self.bias = self.params.get('bias', + shape=(array_2_dim,), + init=initializer.Zero()) + + def hybrid_forward(self, F, array_1, array_2, weight_matrix, bias=None): + # pylint: disable=arguments-differ + combined_tensors = combine_tensors(F, self.combination, [array_1, array_2]) + dot_product = F.FullyConnected(combined_tensors, weight_matrix, bias=bias, flatten=False, + no_bias=not self.use_bias, num_hidden=self.array_2_dim) + + if not self._activation: + return dot_product + + return self._activation(dot_product + bias) diff --git a/scripts/question_answering/tokenizer.py b/scripts/question_answering/tokenizer.py new file mode 100644 index 0000000000..72b11d78a6 --- /dev/null +++ b/scripts/question_answering/tokenizer.py @@ -0,0 +1,76 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tokenizer for SQuAD dataset""" + +import re +import nltk + + +class BiDAFTokenizer: + """Tokenizer that is used for preprocessing data for BiDAF model. It applies basic tokenizer + and some extra preprocessing steps making data ready to be used for training BiDAF + """ + def __call__(self, sample, lower_case=False): + """Process single record + + Parameters + ---------- + sample: str + The record to tokenize + + Returns + ------- + ret : list of strs + List of tokens + """ + sample = sample.replace('\n', ' ').replace(u'\u000A', '').replace(u'\u00A0', '') + + tokens = [token.replace('\'\'', '"').replace('``', '"') for token in + nltk.word_tokenize(sample)] + tokens = BiDAFTokenizer._process_tokens(tokens) + tokens = [token for token in tokens if len(token) > 0] + + if lower_case: + tokens = [token.lower() for token in tokens] + + return tokens + + @staticmethod + def _process_tokens(temp_tokens): + """Process tokens by splitting them if split symbol is encountered + + Parameters + ---------- + temp_tokens: list[str] + Tokens to be processed + + Returns + ------- + tokens : list[str] + List of updated tokens + """ + tokens = [] + splitters = ('-', '\u2212', '\u2014', '\u2013', '/', '~', '"', '\'', '\u201C', + '\u2019', '\u201D', '\u2018', '\u00B0') + + for token in temp_tokens: + tokens.extend(re.split('([{}])'.format(''.join(splitters)), token)) + + return tokens diff --git a/scripts/question_answering/train_question_answering.py b/scripts/question_answering/train_question_answering.py new file mode 100644 index 0000000000..c38805ed71 --- /dev/null +++ b/scripts/question_answering/train_question_answering.py @@ -0,0 +1,769 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Main script to train BiDAF model""" + +import argparse +import copy +import math +import multiprocessing +import logging +import os +from os.path import isfile +import pickle +from time import time + +import mxnet as mx +from mxnet import gluon, init, autograd +from mxnet.gluon import Trainer +from mxnet.gluon.data import DataLoader, SimpleDataset, ArrayDataset +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss + +from gluonnlp.data import SQuAD + +from .data_processing import VocabProvider, SQuADTransform +from .utils import PolyakAveraging +from .performance_evaluator import PerformanceEvaluator +from .question_answering import BiDAFModel +from .question_id_mapper import QuestionIdMapper +from .tokenizer import BiDAFTokenizer +from .utils import logging_config + + +def transform_dataset(dataset, vocab_provider, options, enable_filtering=False): + """Get transformed dataset + + Parameters + ---------- + dataset : `Dataset` + Original dataset + vocab_provider : `VocabularyProvider` + Vocabulary provider + options : `Namespace` + Data transformation arguments + enable_filtering : `Bool` + Remove data that doesn't match BiDAF model requirements + + Returns + ------- + data : SimpleDataset + Transformed dataset + """ + tokenizer = vocab_provider.get_tokenizer() + transformer = SQuADTransform(vocab_provider, options.q_max_len, + options.ctx_max_len, options.word_max_len, options.embedding_size) + + i = 0 + transformed_records = [] + long_context = 0 + long_question = 0 + + for i, record in enumerate(dataset): + if enable_filtering: + tokenized_question = tokenizer(record[2], lower_case=True) + + if len(tokenized_question) > options.q_max_len: + long_question += 1 + continue + + transformed_record = transformer(*record) + + # if answer end index is after ctx_max_len token we do not use this record + if enable_filtering and transformed_record[6][0][1] >= options.ctx_max_len: + continue + + transformed_records.append(transformed_record) + + processed_dataset = SimpleDataset(transformed_records) + print('{}/{} records left. Too long context {}, too long query {}'.format( + len(processed_dataset), i + 1, long_context, long_question)) + return processed_dataset + + +def get_record_per_answer_span(processed_dataset, options): + """Each record might have multiple answers and for training purposes it is better to increase + number of records by creating a record per each answer. + + Parameters + ---------- + processed_dataset : `Dataset` + Transformed dataset, ready to be trained on + options : `Namespace` + Command arguments + + Returns + ------- + data : Tuple + A tuple of dataset and dataloader. Each item in dataset is: + (index, question_word_index, context_word_index, question_char_index, context_char_index, + answers) + """ + data_no_label = [] + labels = [] + global_index = 0 + + # copy records to a record per answer + for r in processed_dataset: + # creating a set out of answer_span will deduplicate them + for answer_span in set(r[6]): + # if after all preprocessing the answer is not in the context anymore, + # the item is filtered out + if options.filter_long_context and (answer_span[0] > r[3].size or + answer_span[1] > r[3].size): + continue + # need to remove question id before feeding the data to data loader + # And I also replace index with global_index when unrolling answers + data_no_label.append((global_index, r[2], r[3], r[4], r[5])) + labels.append(mx.nd.array(answer_span)) + global_index += 1 + + loadable_data = ArrayDataset(data_no_label, labels) + dataloader = DataLoader(loadable_data, + batch_size=options.batch_size * len(get_context(options)), + shuffle=True, + last_batch='rollover', + pin_memory=True, + num_workers=(multiprocessing.cpu_count() - + len(get_context(options)) - 2)) + + print('Total records for training: {}'.format(len(labels))) + return loadable_data, dataloader + + +def get_vocabs(vocab_provider, options): + """Get word-level and character-level vocabularies + + Parameters + ---------- + dataset : `SQuAD` + SQuAD dataset to build vocab from + options : `Namespace` + Vocab building arguments + + Returns + ------- + data : Tuple + A tuple of word vocabulary and character vocabulary + """ + word_vocab = vocab_provider.get_word_level_vocab(options.embedding_size) + char_vocab = vocab_provider.get_char_level_vocab() + return word_vocab, char_vocab + + +def get_context(options): + """Return context list to work on + + Parameters + ---------- + options : `Namespace` + Command arguments + + Returns + ------- + ctx : list[Context] + List of contexts + """ + ctx = [] + + if options.gpu is None: + ctx.append(mx.cpu(0)) + print('Use CPU') + else: + indices = options.gpu.split(',') + + for index in indices: + ctx.append(mx.gpu(int(index))) + + return ctx + + +def run_training(net, dataloader, ctx, options): + """Main function to do training of the network + + Parameters + ---------- + net : `Block` + Network to train + dataloader : `DataLoader` + Initialized dataloader + ctx: `Context` + Training context + options : `Namespace` + Training arguments + """ + + hyperparameters = {'learning_rate': options.lr} + + if options.rho: + hyperparameters['rho'] = options.rho + + trainer = Trainer(net.collect_params(), options.optimizer, hyperparameters, + kvstore='device', update_on_kvstore=False) + + if options.resume_training: + path = os.path.join(options.save_dir, + 'trainer_epoch{:d}.params'.format(options.resume_training - 1)) + trainer.load_states(path) + + loss_function = SoftmaxCrossEntropyLoss() + ema = None + + train_start = time() + avg_loss = mx.nd.zeros((1,), ctx=ctx[0]) + iteration = 1 + max_dev_exact = -1 + max_dev_f1 = -1 + max_iteration = -1 + early_stop_tries = 0 + + print('Starting training...') + + for e in range(0 if not options.resume_training else options.resume_training, + options.epochs): + i = 0 + avg_loss *= 0 # Zero average loss of each epoch + records_per_epoch_count = 0 + + for i, (data, label) in enumerate(dataloader): + # start timing for the first batch of epoch + if i == 0: + e_start = time() + + record_index, q_words, ctx_words, q_chars, ctx_chars = data + records_per_epoch_count += record_index.shape[0] + + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + label = gluon.utils.split_and_load(label, ctx, even_split=False) + + losses = [] + + for _, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): + with autograd.record(): + begin, end = net(qw, cw, qc, cc) + begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) + loss = loss_function(begin, begin_end[0]) + \ + loss_function(end, begin_end[1]) + losses.append(loss) + + for loss in losses: + loss.backward() + + if iteration == 1 and options.use_exponential_moving_average: + ema = PolyakAveraging(net.collect_params(), + options.exponential_moving_average_weight_decay) + + if options.resume_training: + path = os.path.join(options.save_dir, 'epoch{:d}.params'.format( + options.resume_training - 1)) + ema.get_params().load(path) + + # in special mode we collect gradients and apply processing only after + # predefined number of grad_req_add_mode which acts like batch_size counter + if options.grad_req_add_mode > 0: + if not iteration % options.grad_req_add_mode != 0 and \ + iteration != len(dataloader): + iteration += 1 + continue + + if options.lr_warmup_steps: + trainer.set_learning_rate(warm_up_steps(iteration, options)) + + execute_trainer_step(net, trainer, ctx, options) + + if ema is not None: + ema.update() + + if e == options.epochs - 1 and \ + options.log_interval > 0 and \ + iteration > 0 and iteration % options.log_interval == 0: + evaluate_options = copy.deepcopy(options) + evaluate_options.epochs = iteration + eval_result = run_evaluate_mode(evaluate_options, net, ema) + + print('Iteration {} evaluation results on dev dataset: {}'.format(iteration, + eval_result)) + if not options.early_stop: + continue + + if eval_result['f1'] > max_dev_f1: + max_dev_f1 = eval_result['f1'] + max_dev_exact = eval_result['exact_match'] + max_iteration = iteration + early_stop_tries = 0 + else: + early_stop_tries += 1 + if early_stop_tries < options.early_stop: + print('Results decreased for {} times'.format(early_stop_tries)) + else: + print('Results decreased for {} times. Stop training. ' + 'Best results are stored at {} params file. F1={}, EM={}' \ + .format(options.early_stop + 1, max_iteration, + max_dev_f1, max_dev_exact)) + break + + for l in losses: + avg_loss += l.mean().as_in_context(avg_loss.context) + + iteration += 1 + + mx.nd.waitall() + + avg_loss /= (i * len(ctx)) + + # block the call here to get correct Time per epoch + avg_loss_scalar = avg_loss.asscalar() + epoch_time = time() - e_start + + print('\tEPOCH {:2}: train loss {:6.4f} | batch {:4} | lr {:5.3f} ' + '| throughtput {:5.3f} of samples/sec | Time per epoch {:5.2f} seconds' + .format(e, avg_loss_scalar, options.batch_size, trainer.learning_rate, + records_per_epoch_count / epoch_time, epoch_time)) + + save_model_parameters(net.collect_params() if ema is None else ema.get_params(), e, options) + save_trainer_parameters(trainer, e, options) + + if options.terminate_training_on_reaching_F1_threshold: + evaluate_options = copy.deepcopy(options) + evaluate_options.epochs = e + eval_result = run_evaluate_mode(evaluate_options, net, ema) + + if eval_result['f1'] >= options.terminate_training_on_reaching_F1_threshold: + print('Finishing training on {} epoch, because dev F1 score is >= required {}. {}' + .format(e, options.terminate_training_on_reaching_F1_threshold, eval_result)) + break + + print('Training time {:6.2f} seconds'.format(time() - train_start)) + + +def warm_up_steps(iteration, options): + """Returns learning rate based on current iteration. Used to implement learning rate warm up + technique + + Parameters + ---------- + iteration : `int` + Number of iteration + options : `Namespace` + Training options + + Returns + ------- + learning_rate : float + Learning rate + """ + if options.resume_training: + return options.lr + + return min(options.lr, options.lr * (math.log(iteration) / math.log(options.lr_warmup_steps))) + + +def get_gradients(model, ctx, options): + """Get gradients and apply gradient decay to all layers if required. + + Parameters + ---------- + model : `BiDAFModel` + Model in training + ctx : `Context` + Training context + options : `Namespace` + Training options + + Returns + ------- + gradients : list + List of gradients + """ + gradients = [] + + for name, parameter in model.collect_params().items(): + if is_fixed_embedding_layer(name) and not options.train_unk_token: + continue + + grad = parameter.grad(ctx) + + if options.weight_decay: + if is_fixed_embedding_layer(name): + grad[0] += options.weight_decay * parameter.data(ctx)[0] + else: + grad += options.weight_decay * parameter.data(ctx) + + gradients.append(grad) + + return gradients + + +def reset_embedding_gradients(model, ctx): + """Gradients for embedding layer doesn't need to be trained. We train only UNK token of + embedding if required. + + Parameters + ---------- + model : `BiDAFModel` + Model in training + ctx : `Context` + Training context + """ + model.ctx_embedding._word_embedding.weight.grad(ctx=ctx)[1:] = 0 + + +def is_fixed_embedding_layer(name): + """Check if this is an embedding layer which parameters are supposed to be fixed + + Parameters + ---------- + name : `str` + Layer name to check + """ + return True if 'predefined_embedding_layer' in name else False + + +def execute_trainer_step(net, trainer, ctx, options): + """Does training step if doesn't need to do gradient clipping or train unknown symbols. + + Parameters + ---------- + net : `Block` + Network to train + trainer : `Trainer` + Trainer + ctx: list + Context list + options: `SimpleNamespace` + Training options + """ + scailing_coeff = len(ctx) * options.batch_size \ + if options.grad_req_add_mode == 0 else options.grad_req_add_mode + + if options.clip or options.train_unk_token: + trainer.allreduce_grads() + gradients = get_gradients(net, ctx[0], options) + + if options.clip: + gluon.utils.clip_global_norm(gradients, options.clip) + + if options.train_unk_token: + reset_embedding_gradients(net, ctx[0]) + + if len(ctx) > 1: + # in multi gpu mode we propagate new gradients to the rest of gpus + for _, parameter in net.collect_params().items(): + grads = parameter.list_grad() + source = grads[0] + destination = grads[1:] + + for dest in destination: + source.copyto(dest) + + trainer.update(scailing_coeff) + else: + trainer.step(scailing_coeff) + + +def save_model_parameters(params, epoch, options): + """Save parameters of the trained model + + Parameters + ---------- + params : `gluon.ParameterDict` + Model with trained parameters + epoch : `int` + Number of epoch + options : `Namespace` + Saving arguments + """ + if not os.path.exists(options.save_dir): + os.mkdir(options.save_dir) + + save_path = os.path.join(options.save_dir, 'epoch{:d}.params'.format(epoch)) + params.save(save_path) + + +def save_trainer_parameters(trainer, epoch, options): + """Save exponentially averaged parameters of the trained model + + Parameters + ---------- + trainer : `Trainer` + Trainer + epoch : `int` + Number of epoch + options : `Namespace` + Saving arguments + """ + if trainer is None: + return + + if not os.path.exists(options.save_dir): + os.mkdir(options.save_dir) + + save_path = os.path.join(options.save_dir, 'trainer_epoch{:d}.params'.format(epoch)) + trainer.save_states(save_path) + + +def save_transformed_dataset(dataset, path): + """Save processed dataset into a file. + + Parameters + ---------- + dataset : `Dataset` + Dataset to save + path : `str` + Saving path + """ + pickle.dump(dataset, open(path, 'wb')) + + +def load_transformed_dataset(path): + """Loads already preprocessed dataset from disk + + Parameters + ---------- + path : `str` + Loading path + + Returns + ------- + processed_dataset : SimpleDataset + Transformed dataset + """ + processed_dataset = pickle.load(open(path, 'rb')) + return processed_dataset + + +def run_preprocess_mode(options): + """Run program in data preprocessing mode + + Parameters + ---------- + options : `Namespace` + Data preprocessing arguments + """ + # we use both datasets to create proper vocab + dataset_train = SQuAD(segment='train') + dataset_dev = SQuAD(segment='dev') + + vocab_provider = VocabProvider([dataset_train, dataset_dev], options) + transformed_dataset = transform_dataset(dataset_train, vocab_provider, options=options, + enable_filtering=True) + save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) + + if options.preprocessed_val_dataset_path: + transformed_dataset = transform_dataset(dataset_dev, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_val_dataset_path) + + +def run_training_mode(options): + """Run program in data training mode + + Parameters + ---------- + options : `Namespace` + Model training parameters + """ + dataset = SQuAD(segment='train') + dataset_val = SQuAD(segment='dev') + vocab_provider = VocabProvider([dataset, dataset_val], options) + + if options.preprocessed_dataset_path and isfile(options.preprocessed_dataset_path): + transformed_dataset = load_transformed_dataset(options.preprocessed_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=options, + enable_filtering=True) + save_transformed_dataset(transformed_dataset, options.preprocessed_dataset_path) + + _, train_dataloader = get_record_per_answer_span(transformed_dataset, options) + word_vocab, char_vocab = get_vocabs(vocab_provider, options=options) + ctx = get_context(options) + + net = BiDAFModel(word_vocab, char_vocab, options, prefix='bidaf') + net.initialize(init.Xavier(), ctx=ctx) + net.hybridize(static_alloc=True) + + if options.grad_req_add_mode: + net.collect_params().setattr('grad_req', 'add') + + if options.resume_training: + print('Resuming training from {} epoch'.format(options.resume_training)) + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(int(options.resume_training) - 1)) + net.load_parameters(params_path, ctx) + + run_training(net, train_dataloader, ctx, options=options) + + +def run_evaluate_mode(options, existing_net=None, existing_ema=None): + """Run program in evaluating mode + + Parameters + ---------- + existing_net : `Block` + Trained existing network + existing_net : `PolyakAveraging` + Averaged parameters of the network + options : `Namespace` + Model evaluation arguments + + Returns + ------- + result : dict + Dictionary with exact_match and F1 scores + """ + train_dataset = SQuAD(segment='train') + dataset = SQuAD(segment='dev') + + vocab_provider = VocabProvider([train_dataset, dataset], options) + mapper = QuestionIdMapper(dataset) + + if options.preprocessed_val_dataset_path and isfile(options.preprocessed_val_dataset_path): + transformed_dataset = load_transformed_dataset(options.preprocessed_val_dataset_path) + else: + transformed_dataset = transform_dataset(dataset, vocab_provider, options=options) + save_transformed_dataset(transformed_dataset, options.preprocessed_val_dataset_path) + + word_vocab, char_vocab = get_vocabs(vocab_provider, options=options) + ctx = get_context(options) + + evaluator = PerformanceEvaluator(BiDAFTokenizer(), transformed_dataset, + dataset._read_data(), mapper) + + net = BiDAFModel(word_vocab, char_vocab, options, prefix='bidaf') + + if existing_ema is not None: + save_model_parameters(existing_ema.get_params(), options.epochs, options) + + elif existing_net is not None: + save_model_parameters(existing_net.collect_params(), options.epochs, options) + + params_path = os.path.join(options.save_dir, + 'epoch{:d}.params'.format(int(options.epochs) - 1)) + + net.collect_params().load(params_path, ctx=ctx) + net.hybridize(static_alloc=True) + return evaluator.evaluate_performance(net, ctx, options) + + +def get_args(): + """Get console arguments + """ + parser = argparse.ArgumentParser(description='Question Answering example using BiDAF & SQuAD') + parser.add_argument('--preprocess', default=False, action='store_true', + help='Preprocess dataset') + parser.add_argument('--train', default=False, action='store_true', + help='Run training') + parser.add_argument('--evaluate', default=False, action='store_true', + help='Run evaluation on dev dataset') + parser.add_argument('--preprocessed_dataset_path', type=str, + default='preprocessed_dataset.p', help='Path to preprocessed dataset') + parser.add_argument('--preprocessed_val_dataset_path', type=str, + default='preprocessed_val_dataset.p', + help='Path to preprocessed validation dataset') + parser.add_argument('--epochs', type=int, default=12, help='Upper epoch limit') + parser.add_argument('--embedding_size', type=int, default=100, + help='Dimension of the word embedding') + parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--ctx_embedding_num_layers', type=int, default=2, + help='Number of layers in Contextual embedding layer of BiDAF') + parser.add_argument('--highway_num_layers', type=int, default=2, + help='Number of layers in Highway layer of BiDAF') + parser.add_argument('--modeling_num_layers', type=int, default=2, + help='Number of layers in Modeling layer of BiDAF') + parser.add_argument('--output_num_layers', type=int, default=1, + help='Number of layers in Output layer of BiDAF') + parser.add_argument('--batch_size', type=int, default=60, help='Batch size') + parser.add_argument('--ctx_max_len', type=int, default=400, help='Maximum length of a context') + parser.add_argument('--q_max_len', type=int, default=30, help='Maximum length of a question') + parser.add_argument('--word_max_len', type=int, default=16, help='Maximum characters in a word') + parser.add_argument('--answer_max_len', type=int, default=30, help='Maximum tokens in answer') + parser.add_argument('--optimizer', type=str, default='adadelta', help='optimization algorithm') + parser.add_argument('--lr', type=float, default=0.5, help='Initial learning rate') + parser.add_argument('--rho', type=float, default=0.9, + help='Adadelta decay rate for both squared gradients and delta.') + parser.add_argument('--lr_warmup_steps', type=int, default=0, + help='Defines how many iterations to spend on warming up learning rate') + parser.add_argument('--clip', type=float, default=0, help='gradient clipping') + parser.add_argument('--weight_decay', type=float, default=0, + help='Weight decay for parameter updates') + parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='Report interval applied to last epoch only') + parser.add_argument('--early_stop', type=int, default=9, + help='Apply early stopping for the last epoch. Stop after # of consequent ' + '# of times F1 is lower than max. Should be used with log_interval') + parser.add_argument('--resume_training', type=int, default=0, + help='Resume training from this epoch number') + parser.add_argument('--terminate_training_on_reaching_F1_threshold', type=float, default=0, + help='Some tasks, like DAWNBenchmark requires to minimize training time ' + 'while reaching a particular F1 metric. This parameter controls if ' + 'training should be terminated as soon as F1 is reached to minimize ' + 'training time and cost. It would force to do evaluation every epoch.') + parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') + parser.add_argument('--word_vocab_path', type=str, default=None, + help='Path to preprocessed word-level vocabulary') + parser.add_argument('--char_vocab_path', type=str, default=None, + help='Path to preprocessed character-level vocabulary') + parser.add_argument('--gpu', type=str, default=None, + help='Coma-separated ids of the gpu to use. Empty means to use cpu.') + parser.add_argument('--train_unk_token', default=False, action='store_true', + help='Should train unknown token of embedding') + parser.add_argument('--filter_long_context', default=True, action='store_false', + help='Filter contexts if the answer is after ctx_max_len') + parser.add_argument('--save_prediction_path', type=str, default='', + help='Path to save predictions') + parser.add_argument('--use_exponential_moving_average', default=True, action='store_false', + help='Should averaged copy of parameters been stored and used ' + 'during evaluation.') + parser.add_argument('--exponential_moving_average_weight_decay', type=float, default=0.999, + help='Weight decay used in exponential moving average') + parser.add_argument('--grad_req_add_mode', type=int, default=0, + help='Enable rolling gradient mode, where batch size is always 1 and ' + 'gradients are accumulated using single GPU') + + options = parser.parse_args() + return options + + +if __name__ == '__main__': + args = get_args() + args.batch_size = int(args.batch_size / len(get_context(args))) + print(args) + logging_config(args.save_dir) + + if args.preprocess: + if not args.preprocessed_dataset_path: + logging.error('Preprocessed_data_path attribute is not provided') + exit(1) + + print('Running in preprocessing mode') + run_preprocess_mode(args) + + if args.train: + print('Running in training mode') + run_training_mode(args) + + if args.evaluate: + print('Running in evaluation mode') + result = run_evaluate_mode(args) + print('Evaluation results on dev dataset: {}'.format(result)) diff --git a/scripts/question_answering/utils.py b/scripts/question_answering/utils.py new file mode 100644 index 0000000000..fc1ec73ad0 --- /dev/null +++ b/scripts/question_answering/utils.py @@ -0,0 +1,379 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Set of utility methods for question answering models""" + +import inspect +import logging +import os + +import mxnet as mx +from mxnet import nd, gluon + + +def logging_config(folder=None, name=None, level=logging.DEBUG, console_level=logging.INFO, + no_console=False): + """ Config the logging. + + Parameters + ---------- + folder : str or None + name : str or None + level : int + console_level + no_console: bool + Whether to disable the console log + Returns + ------- + folder : str + Folder that the logging file will be saved into. + """ + if name is None: + name = inspect.stack()[1][1].split('.')[0] + + if folder is None: + folder = os.path.join(os.getcwd(), name) + + if not os.path.exists(folder): + os.makedirs(folder) + + # Remove all the current handlers + for handler in logging.root.handlers: + logging.root.removeHandler(handler) + + logging.root.handlers = [] + logpath = os.path.join(folder, name + '.log') + print('All Logs will be saved to {}'.format(logpath)) + logging.root.setLevel(level) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s') + logfile = logging.FileHandler(logpath) + logfile.setLevel(level) + logfile.setFormatter(formatter) + logging.root.addHandler(logfile) + + if not no_console: + # Initialze the console logging + logconsole = logging.StreamHandler() + logconsole.setLevel(console_level) + logconsole.setFormatter(formatter) + logging.root.addHandler(logconsole) + + return folder + + +def get_very_negative_number(): + return -1e30 + + +def extend_to_batch_size(batch_size, prototype, fill_value=0): + """Provides NDArray, which consist of prototype NDArray and NDArray filled with fill_value to + batch_size number of items. New NDArray appended to batch dimension (dim=0). + + Parameters + ---------- + batch_size: ``int`` + Expected value for batch_size dimension (dim=0). + prototype: ``NDArray`` + NDArray to be extended of shape (batch_size, ...) + fill_value: ``float`` + Value to use for filling + """ + if batch_size == prototype.shape[0]: + return prototype + + new_shape = (batch_size - prototype.shape[0],) + prototype.shape[1:] + dummy_elements = nd.full(val=fill_value, shape=new_shape, dtype=prototype.dtype, + ctx=prototype.context) + return nd.concat(prototype, dummy_elements, dim=0) + + +def get_combined_dim(combination, tensor_dims): + """ + For use with :func:`combine_tensors`. This function computes the resultant dimension when + calling ``combine_tensors(combination, tensors)``, when the tensor dimension is known. This is + necessary for knowing the sizes of weight matrices when building models that use + ``combine_tensors``. + + Parameters + ---------- + combination : ``str`` + A comma-separated list of combination pieces, like ``"1,2,1*2"``, specified identically to + ``combination`` in :func:`combine_tensors`. + tensor_dims : ``List[int]`` + A list of tensor dimensions, where each dimension is from the `last axis` of the tensors + that will be input to :func:`combine_tensors`. + """ + combination = combination.replace('x', '1').replace('y', '2') + return sum([_get_combination_dim(piece, tensor_dims) for piece in combination.split(',')]) + + +def _get_combination_dim(combination, tensor_dims): + if combination.isdigit(): + index = int(combination) - 1 + return tensor_dims[index] + else: + first_tensor_dim = _get_combination_dim(combination[0], tensor_dims) + return first_tensor_dim + + +def _get_combination(combination, tensors): + if combination.isdigit(): + index = int(combination) - 1 + return tensors[index] + else: + first_tensor = _get_combination(combination[0], tensors) + second_tensor = _get_combination(combination[2], tensors) + operation = combination[1] + if operation == '*': + return first_tensor * second_tensor + elif operation == '/': + return first_tensor / second_tensor + elif operation == '+': + return first_tensor + second_tensor + elif operation == '-': + return first_tensor - second_tensor + else: + raise NotImplementedError + + +def combine_tensors(F, combination, tensors): + """ + Combines a list of tensors using element-wise operations and concatenation, specified by a + ``combination`` string. The string refers to (1-indexed) positions in the input tensor list, + and looks like ``"1,2,1+2,3-1"``. + + We allow the following kinds of combinations: ``x``, ``x*y``, ``x+y``, ``x-y``, and ``x/y``, + where ``x`` and ``y`` are positive integers less than or equal to ``len(tensors)``. Each of + the binary operations is performed elementwise. You can give as many combinations as you want + in the ``combination`` string. For example, for the input string ``"1,2,1*2"``, the result + would be ``[1;2;1*2]``, as you would expect, where ``[;]`` is concatenation along the last + dimension. + + If you have a fixed, known way to combine tensors that you use in a model, you should probably + just use something like ``ndarray.cat([x_tensor, y_tensor, x_tensor * y_tensor])``. This + function adds some complexity that is only necessary if you want the specific combination used + to be `configurable`. + + If you want to do any element-wise operations, the tensors involved in each element-wise + operation must have the same shape. + + This function also accepts ``x`` and ``y`` in place of ``1`` and ``2`` in the combination + string. + """ + combination = combination.replace('x', '1').replace('y', '2') + to_concatenate = [_get_combination(piece, tensors) for piece in combination.split(',')] + return F.concat(*to_concatenate, dim=-1) + + +def masked_softmax(F, vector, mask, epsilon): + """ + ``nd.softmax(vector)`` does not work if some elements of ``vector`` should be + masked. This performs a softmax on just the non-masked portions of ``vector``. Passing + ``None`` in for the mask is also acceptable; you'll just get a regular softmax. + + We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``. + + In the case that the input vector is completely masked, this function returns an array + of ``0.0``. This behavior may cause ``NaN`` if this is used as the last layer of a model + that uses categorical cross-entropy loss. + """ + if mask is None: + result = F.softmax(vector, axis=-1) + else: + # To limit numerical errors from large vector elements outside the mask, we zero these out. + result = F.softmax(vector * mask, axis=-1) + result = result * mask + result = F.broadcast_div(result, (result.sum(axis=1, keepdims=True) + epsilon)) + return result + + +def masked_log_softmax(vector, mask): + """ + ``nd.log_softmax(vector)`` does not work if some elements of ``vector`` should be + masked. This performs a log_softmax on just the non-masked portions of ``vector``. Passing + ``None`` in for the mask is also acceptable; you'll just get a regular log_softmax. + + We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``. + + In the case that the input vector is completely masked, the return value of this function is + arbitrary, but not ``nan``. You should be masking the result of whatever computation comes out + of this in that case, anyway, so the specific values returned shouldn't matter. Also, the way + that we deal with this case relies on having single-precision floats; mixing half-precision + floats with fully-masked vectors will likely give you ``nans``. + + If your logits are all extremely negative (i.e., the max value in your logit vector is -50 or + lower), the way we handle masking here could mess you up. But if you've got logit values that + extreme, you've got bigger problems than this. + """ + if mask is not None: + # vector + mask.log() is an easy way to zero out masked elements in logspace, but it + # results in nans when the whole vector is masked. We need a very small value instead of a + # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely + # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it + # becomes 0 - this is just the smallest value we can actually use. + vector = vector + (mask + 1e-45).log() + return nd.log_softmax(vector, axis=1) + + +def _last_dimension_applicator(F, + function_to_apply, + tensor, + mask, + tensor_shape, + mask_shape, + **kwargs): + """ + Takes a tensor with 3 or more dimensions and applies a function over the last dimension. We + assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) + has shape ``(batch_size, sequence_length)``. We first unsqueeze and expand the mask so that it + has the same shape as the tensor, then flatten them both to be 2D, pass them through + the function and put the tensor back in its original shape. + """ + reshaped_tensor = tensor.reshape(shape=(-1, tensor_shape[-1])) + + if mask is not None: + shape_difference = len(tensor_shape) - len(mask_shape) + for _ in range(0, shape_difference): + mask = mask.expand_dims(1) + mask = mask.broadcast_to(shape=tensor_shape) + mask = mask.reshape(shape=(-1, mask_shape[-1])) + reshaped_result = function_to_apply(F, reshaped_tensor, mask, **kwargs) + return reshaped_result.reshape(shape=tensor_shape) + + +def last_dim_softmax(F, tensor, mask, tensor_shape, mask_shape, epsilon): + """ + Takes a tensor with 3 or more dimensions and does a masked softmax over the last dimension. We + assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given) + has shape ``(batch_size, sequence_length)``. + """ + return _last_dimension_applicator(F, masked_softmax, tensor, mask, tensor_shape, mask_shape, + epsilon=epsilon) + + +def last_dim_log_softmax(F, tensor, mask, tensor_shape, mask_shape): + """ + Takes a tensor with 3 or more dimensions and does a masked log softmax over the last dimension. + We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask + (if given) has shape ``(batch_size, sequence_length)``. + """ + return _last_dimension_applicator(F, masked_log_softmax, tensor, mask, tensor_shape, mask_shape) + + +def weighted_sum(F, matrix, attention, matrix_shape, attention_shape): + """ + Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an + "attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical + computation performed after an attention mechanism. + + Note that while we call this a "matrix" of vectors and an attention "vector", we also handle + higher-order tensors. We always sum over the second-to-last dimension of the "matrix", and we + assume that all dimensions in the "matrix" prior to the last dimension are matched in the + "vector". Non-matched dimensions in the "vector" must be `directly after the batch dimension`. + + For example, say I have a "matrix" with dimensions ``(batch_size, num_queries, num_words, + embedding_dim)``. The attention "vector" then must have at least those dimensions, and could + have more. Both: + + - ``(batch_size, num_queries, num_words)`` (distribution over words for each query) + - ``(batch_size, num_documents, num_queries, num_words)`` (distribution over words in a + query for each document) + + are valid input "vectors", producing tensors of shape: + ``(batch_size, num_queries, embedding_dim)`` and + ``(batch_size, num_documents, num_queries, embedding_dim)`` respectively. + """ + + if len(attention_shape) == 2 and len(matrix_shape) == 3: + return F.squeeze(F.batch_dot(attention.expand_dims(1), matrix), axis=1) + if len(attention_shape) == 3 and len(matrix_shape) == 3: + return F.batch_dot(attention, matrix) + if len(matrix_shape) - 1 < len(attention_shape): + expanded_size = list(matrix_shape) + for i in range(len(attention_shape) - len(matrix_shape) + 1): + matrix = matrix.expand_dims(1) + expanded_size.insert(i + 1, attention.shape[i + 1]) + matrix = matrix.broadcast_to(*expanded_size) + intermediate = attention.expand_dims(-1).broadcast_to(matrix_shape) * matrix + return intermediate.sum(axis=-2) + + +def replace_masked_values(F, tensor, mask, replace_with): + """ + Replaces all masked values in ``tensor`` with ``replace_with``. ``mask`` must be broadcastable + to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we + won't know which dimensions of the mask to unsqueeze. + """ + # We'll build a tensor of the same shape as `tensor`, zero out masked values, then add back in + # the `replace_with` value. + one_minus_mask = 1.0 - mask + values_to_add = replace_with * one_minus_mask + return F.broadcast_add(F.broadcast_mul(tensor, mask), values_to_add) + + +class PolyakAveraging: + """Class to do Polyak averaging based on this paper + http://www.meyn.ece.ufl.edu/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf""" + + def __init__(self, params, decay): + self._params = params + self._decay = decay + + self._polyak_params_dict = gluon.ParameterDict() + + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name, shape=param.shape) + polyak_param.initialize(mx.init.Constant(self._param_data_to_cpu(param)), ctx=mx.cpu()) + + def update(self): + """ + Updates currently held saved parameters with current state of network. + + All calculations for this average occur on the cpu context. + """ + for param in self._params.values(): + polyak_param = self._polyak_params_dict.get(param.name) + polyak_param.set_data( + (1 - self._decay) * self._param_data_to_cpu(param) + + self._decay * polyak_param.data(mx.cpu())) + + def get_params(self): + """ Provides averaged parameters + + Returns + ------- + gluon.ParameterDict + Averaged parameters + """ + return self._polyak_params_dict + + def _param_data_to_cpu(self, param): + """Returns a copy (on CPU context) of the data held in some context of given parameter. + + Parameters + ---------- + param: gluon.Parameter + Parameter's whose data needs to be copied. + + Returns + ------- + NDArray + Copy of data on CPU context. + """ + return param.list_data()[0].copyto(mx.cpu()) diff --git a/scripts/tests/test_question_answering.py b/scripts/tests/test_question_answering.py index 729fdf6e9f..c74ce8dab1 100644 --- a/scripts/tests/test_question_answering.py +++ b/scripts/tests/test_question_answering.py @@ -17,22 +17,42 @@ # specific language governing permissions and limitations # under the License. import os + import pytest +import mxnet as mx +from mxnet import init, nd, autograd, gluon +from mxnet.gluon import Trainer, nn from mxnet.gluon.data import DataLoader, SimpleDataset +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss +from types import SimpleNamespace +import gluonnlp as nlp from gluonnlp.data import SQuAD -from ..question_answering.data_processing import SQuADTransform, VocabProvider +from scripts.question_answering.attention_flow import AttentionFlow +from scripts.question_answering.bidaf import BidirectionalAttentionFlow +from scripts.question_answering.data_processing import SQuADTransform, VocabProvider +from scripts.question_answering.utils import PolyakAveraging +from scripts.question_answering.question_answering import * +from scripts.question_answering.similarity_function import LinearSimilarity +from scripts.question_answering.tokenizer import BiDAFTokenizer +from scripts.question_answering.train_question_answering import get_record_per_answer_span + + +batch_size = 5 question_max_length = 30 -context_max_length = 256 +context_max_length = 400 +max_chars_per_word = 16 +embedding_size = 100 @pytest.mark.serial def test_transform_to_nd_array(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word, embedding_size) record = dataset[0] transformed_record = transformer(*record) @@ -43,8 +63,9 @@ def test_transform_to_nd_array(): @pytest.mark.serial def test_data_loader_able_to_read(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) - transformer = SQuADTransform(vocab_provider, question_max_length, context_max_length) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word, embedding_size) record = dataset[0] processed_dataset = SimpleDataset([transformer(*record)]) @@ -65,7 +86,328 @@ def test_data_loader_able_to_read(): @pytest.mark.serial def test_load_vocabs(): dataset = SQuAD(segment='dev', root='tests/data/squad') - vocab_provider = VocabProvider(dataset) + vocab_provider = VocabProvider([dataset], get_args(batch_size)) - assert vocab_provider.get_word_level_vocab() is not None + assert vocab_provider.get_word_level_vocab(embedding_size) is not None assert vocab_provider.get_char_level_vocab() is not None + + +def test_bidaf_embedding(): + dataset = SQuAD(segment='dev', root='tests/data/squad') + vocab_provider = VocabProvider([dataset], get_args(batch_size)) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word, embedding_size) + + # for performance reason, process only batch_size # of records + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < batch_size]) + + # need to remove question id before feeding the data to data loader + loadable_data, dataloader = get_record_per_answer_span(processed_dataset, get_args(batch_size)) + + word_vocab = vocab_provider.get_word_level_vocab(embedding_size) + word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) + char_vocab = vocab_provider.get_char_level_vocab() + + embedding = BiDAFEmbedding(word_vocab=word_vocab, + char_vocab=char_vocab, + batch_size=batch_size, + max_seq_len=question_max_length) + embedding.initialize(init.Xavier(magnitude=2.24), ctx=mx.cpu_pinned()) + embedding.hybridize(static_alloc=True) + + trainer = Trainer(embedding.collect_params(), "sgd", {"learning_rate": 0.1}) + + for i, (data, label) in enumerate(dataloader): + with autograd.record(): + record_index, q_words, ctx_words, q_chars, ctx_chars = data + # passing only question_words_nd and question_chars_nd batch + out = embedding(q_words, q_chars) + assert out is not None + + out.backward() + trainer.step(batch_size) + break + + +def test_attention_layer(): + ctx_fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 2 * embedding_size)) + + q_fake_data = nd.random.uniform(shape=(batch_size, question_max_length, 2 * embedding_size)) + + ctx_fake_mask = nd.ones(shape=(batch_size, context_max_length)) + q_fake_mask = nd.ones(shape=(batch_size, question_max_length)) + + matrix_attention = AttentionFlow(LinearSimilarity(array_1_dim=6 * embedding_size, + array_2_dim=1, + combination="x,y,x*y"), + batch_size, + context_max_length, + question_max_length, + 2 * embedding_size) + + layer = BidirectionalAttentionFlow(batch_size, + context_max_length, + question_max_length, + 2 * embedding_size) + + matrix_attention.initialize() + layer.initialize() + + with autograd.record(): + passage_question_similarity = matrix_attention(ctx_fake_data, q_fake_data).reshape( + shape=(batch_size, context_max_length, question_max_length)) + + output = layer(passage_question_similarity, ctx_fake_data, q_fake_data, + q_fake_mask, ctx_fake_mask) + + assert output.shape == (batch_size, context_max_length, 8 * embedding_size) + + +def test_modeling_layer(): + # The modeling layer receive input in a shape of batch_size x T x 8d + # T is the sequence length of context which is context_max_length + # d is the size of embedding, which is embedding_size + fake_data = nd.random.uniform(shape=(batch_size, context_max_length, 8 * embedding_size)) + # We assume that attention is already return data in TNC format + attention_output = nd.transpose(fake_data, axes=(1, 0, 2)) + + layer = BiDAFModelingLayer(batch_size) + layer.initialize() + layer.hybridize(static_alloc=True) + + trainer = Trainer(layer.collect_params(), "sgd", {"learning_rate": "0.1"}) + + with autograd.record(): + output = layer(attention_output) + + output.backward() + # According to the paper, the output should be 2d x T + assert output.shape == (context_max_length, batch_size, 2 * embedding_size) + + +def test_output_layer(): + # The output layer receive 2 inputs: the output of Modeling layer (context_max_length, + # batch_size, 2 * embedding_size) and the output of Attention flow layer + # (batch_size, context_max_length, 8 * embedding_size) + + # The modeling layer returns data in TNC format + modeling_output = nd.random.uniform(shape=(context_max_length, batch_size, 2 * embedding_size)) + # The layer assumes that attention is already return data in TNC format + attention_output = nd.random.uniform(shape=(context_max_length, batch_size, 8 * embedding_size)) + ctx_mask = nd.ones(shape=(batch_size, context_max_length)) + + layer = BiDAFOutputLayer(batch_size) + # The model doesn't need to know the hidden states, so I don't hold variables for the states + layer.initialize() + layer.hybridize(static_alloc=True) + + with autograd.record(): + output = layer(attention_output, modeling_output, ctx_mask) + + # We expect final numbers as batch_size x 2 (first start index, second end index) + assert output[0].shape == (batch_size, 400) and \ + output[1].shape == (batch_size, 400) + + +def test_bidaf_model(): + options = get_args(batch_size) + ctx = [mx.cpu(0)] + + dataset = SQuAD(segment='dev', root='tests/data/squad') + vocab_provider = VocabProvider([dataset], options) + transformer = SQuADTransform(vocab_provider, question_max_length, + context_max_length, max_chars_per_word, embedding_size) + + # for performance reason, process only batch_size # of records + processed_dataset = SimpleDataset([transformer(*record) for i, record in enumerate(dataset) + if i < options.batch_size * len(ctx)]) + + # need to remove question id before feeding the data to data loader + train_dataset, train_dataloader = get_record_per_answer_span(processed_dataset, options) + + word_vocab = vocab_provider.get_word_level_vocab(embedding_size) + word_vocab.set_embedding(nlp.embedding.create('glove', source='glove.6B.100d')) + char_vocab = vocab_provider.get_char_level_vocab() + + net = BiDAFModel(word_vocab=word_vocab, + char_vocab=char_vocab, + options=options) + + net.initialize(init.Xavier(magnitude=2.24)) + net.hybridize(static_alloc=True) + + loss_function = SoftmaxCrossEntropyLoss() + trainer = Trainer(net.collect_params(), "adadelta", {"learning_rate": 0.5}) + + for i, (data, label) in enumerate(train_dataloader): + record_index, q_words, ctx_words, q_chars, ctx_chars = data + + record_index = gluon.utils.split_and_load(record_index, ctx, even_split=False) + q_words = gluon.utils.split_and_load(q_words, ctx, even_split=False) + ctx_words = gluon.utils.split_and_load(ctx_words, ctx, even_split=False) + q_chars = gluon.utils.split_and_load(q_chars, ctx, even_split=False) + ctx_chars = gluon.utils.split_and_load(ctx_chars, ctx, even_split=False) + label = gluon.utils.split_and_load(label, ctx, even_split=False) + + losses = [] + + for ri, qw, cw, qc, cc, l in zip(record_index, q_words, ctx_words, + q_chars, ctx_chars, label): + with autograd.record(): + begin, end = net(qw, cw, qc, cc) + begin_end = l.split(axis=1, num_outputs=2, squeeze_axis=1) + loss = loss_function(begin, begin_end[0]) + \ + loss_function(end, begin_end[1]) + losses.append(loss) + + for loss in losses: + loss.backward() + + trainer.step(options.batch_size) + break + + +def test_get_answer_spans_exact_match(): + tokenizer = BiDAFTokenizer() + + context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." + context_tokens = tokenizer(context) + + answer_start_index = 3 + answer = "Saint Bernadette Soubirous" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(1, 3)] + + +def test_get_answer_spans_partial_match(): + tokenizer = BiDAFTokenizer() + + context = "In addition, trucks will be allowed to enter India's capital only after 11 p.m., two hours later than the existing restriction" + context_tokens = tokenizer(context) + + answer_start_index = 72 + answer = "11 p.m" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(15, 16)] + + +def test_get_answer_spans_unicode(): + tokenizer = BiDAFTokenizer() + + context = "Back in Warsaw that year, Chopin heard Niccolò Paganini play" + context_tokens = tokenizer(context, lower_case=True) + + answer_start_index = 39 + answer = "Niccolò Paganini" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(8, 9)] + + +def test_get_answer_spans_after_comma(): + tokenizer = BiDAFTokenizer() + + context = "Chopin's successes as a composer and performer opened the door to western Europe for him, and on 2 November 1830, he set out," + context_tokens = tokenizer(context, lower_case=True) + + answer_start_index = 108 + answer = "1830" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_index]) + + assert result == [(22, 22)] + + +def test_get_answer_spans_after_quotes(): + tokenizer = BiDAFTokenizer() + + context = "In the film Knute Rockne, All American, Knute Rockne (played by Pat O'Brien) delivers the famous ""Win one for the Gipper"" speech, at which point the background music swells with the ""Notre Dame Victory March"". George Gipp was played by Ronald Reagan, whose nickname ""The Gipper"" was derived from this role. This scene was parodied in the movie Airplane! with the same background music, only this time honoring George Zipp, one of Ted Striker's former comrades. The song also was prominent in the movie Rudy, with Sean Astin as Daniel ""Rudy"" Ruettiger, who harbored dreams of playing football at the University of Notre Dame despite significant obstacles." + context_tokens = tokenizer(context, lower_case=True) + indices = SQuADTransform.get_char_indices(context, context_tokens) + + answer_start_char_index = 267 + answer = "The Gipper" + + result = SQuADTransform._get_answer_spans(context, context_tokens, + [answer], [answer_start_char_index]) + + assert result == [(54, 55)] + + +def test_get_char_indices(): + context = "to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary." + tokenizer = BiDAFTokenizer() + context_tokens = tokenizer(context, lower_case=True) + + result = SQuADTransform.get_char_indices(context, context_tokens) + assert len(result) == len(context_tokens) + + +def test_tokenizer_split_new_lines(): + context = "that are of equal energy\u2014i.e., degenerate\u2014is a configuration termed a spin triplet state. Hence, the ground state of the O\n2 molecule is referred to as triplet oxygen" + + tokenizer = BiDAFTokenizer() + context_tokens = tokenizer(context, lower_case=True) + + assert len(context_tokens) == 35 + + +def test_polyak_averaging(): + net = nn.HybridSequential() + net.add(nn.Dense(5), nn.Dense(3), nn.Dense(2)) + net.initialize(init.Xavier()) + net.hybridize() + + ema = None + loss_fn = SoftmaxCrossEntropyLoss() + trainer = Trainer(net.collect_params(), "sgd", {"learning_rate": 0.5}) + + train_data = mx.random.uniform(-0.1, 0.1, shape=(5, 10)) + train_label = mx.nd.array([0, 1, 1, 0, 1]) + + for i in range(3): + with autograd.record(): + o = net(train_data) + loss = loss_fn(o, train_label) + + if i == 0: + ema = PolyakAveraging(net.collect_params(), decay=0.999) + + loss.backward() + trainer.step(5) + ema.update() + + assert ema.get_params() is not None + +def get_args(batch_size): + options = SimpleNamespace() + options.gpu = None + options.ctx_embedding_num_layers = 2 + options.embedding_size = 100 + options.dropout = 0.2 + options.ctx_embedding_num_layers = 2 + options.highway_num_layers = 2 + options.modeling_num_layers = 2 + options.output_num_layers = 2 + options.batch_size = batch_size + options.ctx_max_len = context_max_length + options.q_max_len = question_max_length + options.word_max_len = max_chars_per_word + options.epochs = 12 + options.save_dir = "output/" + options.filter_long_context = False + options.word_vocab_path = "" + options.char_vocab_path = "" + options.train_unk_token = False + + return options