diff --git a/docs/AdvancedNAS.md b/docs/AdvancedNAS.md
new file mode 100644
index 0000000000..3d2dd986bb
--- /dev/null
+++ b/docs/AdvancedNAS.md
@@ -0,0 +1,71 @@
+# Tutorial for Advanced Neural Architecture Search
+Currently many of the NAS algorithms leverage the technique of **weight sharing** among trials to accelerate the training process. For example, [ENAS][1] delivers 1000x efficiency with '_parameter sharing between child models_', compared with the previous [NASNet][2] algorithm. Other NAS algorithms, such as [DARTS][3], [Network Morphism][4], and [Evolution][5], also leverage, or have the potential to leverage, weight sharing.
+
+This is a tutorial on how to enable weight sharing in NNI.
+
+## Weight Sharing among trials
+Currently we recommend sharing weights through NFS (Network File System), which supports sharing files across machines and is lightweight and (relatively) efficient. We also welcome contributions from the community on more efficient techniques.
+
+### NFS Setup
+In NFS, files are physically stored on a server machine, and trials on the client machine can read/write those files in the same way that they access local files.
+
+#### Install NFS on server machine
+First, install NFS server:
+```bash
+sudo apt-get install nfs-kernel-server
+```
+Suppose `/tmp/nni/shared` is used as the physical storage, then run:
+```bash
+sudo mkdir -p /tmp/nni/shared
+echo "/tmp/nni/shared *(rw,sync,no_subtree_check,no_root_squash)" | sudo tee -a /etc/exports
+sudo service nfs-kernel-server restart
+```
+You can check whether the above directory is successfully exported by NFS using `sudo showmount -e localhost`.
+
+#### Install NFS on client machine
+First, install NFS client:
+```bash
+sudo apt-get install nfs-common
+```
+Then create & mount the shared directory:
+```bash
+sudo mkdir -p /mnt/nfs/nni/
+sudo mount -t nfs 10.10.10.10:/tmp/nni/shared /mnt/nfs/nni
+```
+where `10.10.10.10` should be replaced by the real IP of the NFS server machine in practice.
+
+### Weight Sharing through NFS file
+With NFS set up, trial code can share model weights through loading & saving files. For example, in TensorFlow:
+```python
+# save models
+saver = tf.train.Saver()
+saver.save(sess, os.path.join(params['save_path'], 'model.ckpt'))
+# load models
+tf.train.init_from_checkpoint(params['restore_path'])
+```
+where the hyper-parameters `'save_path'` and `'restore_path'` can be managed by the tuner. For instance, the tuner can give each trial a save directory named after its own parameter id, and point `'restore_path'` at the directory of the trial's parent. A minimal sketch follows (the `save_dir_root` attribute and the `random_individual` helper are illustrative assumptions, not part of the NNI API):
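+```python
+import os
+
+def generate_parameters(self, parameter_id):
+    # Hypothetical helper that picks the individual (architecture) for the
+    # new trial; `indiv.parent_id` is None for models trained from scratch.
+    indiv = self.random_individual()
+    params = indiv.config
+    # Each trial saves its weights under a directory named after its own id
+    # on the shared NFS mount, and restores from its parent's directory.
+    params['save_path'] = os.path.join(self.save_dir_root, str(parameter_id))
+    if indiv.parent_id is not None:
+        params['restore_path'] = os.path.join(self.save_dir_root, str(indiv.parent_id))
+    return params
+```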
+
+## Asynchronous Dispatcher Mode for trial dependency control
+The feature of weight sharing enables trials to run on different machines, and most of the time **read-after-write** consistency must be assured: after all, the child model should not load the parent model before the parent trial finishes training. To deal with this, users can enable **asynchronous dispatcher mode** with `multiThread: true` in NNI's `config.yml`, where the dispatcher assigns a tuner thread each time a `NEW_TRIAL` request comes in, and the tuner thread can decide when to submit a new trial by blocking and unblocking the thread itself. For example:
+```python
+    def generate_parameters(self, parameter_id):
+        self.thread_lock.acquire()
+        indiv = ...  # configuration for a new trial
+        self.events[parameter_id] = threading.Event()
+        self.thread_lock.release()
+        if indiv.parent_id is not None:
+            self.events[indiv.parent_id].wait()
+        return indiv.config
+
+    def receive_trial_result(self, parameter_id, parameters, reward):
+        self.thread_lock.acquire()
+        # code for processing trial results
+        self.thread_lock.release()
+        self.events[parameter_id].set()
+```
+Here `generate_parameters` blocks on the parent trial's event, and `receive_trial_result` sets that event once the parent's result (and therefore its saved weights) is available, so a child trial is only submitted after its parent finishes.
+
+
+[1]: https://arxiv.org/abs/1802.03268
+[2]: https://arxiv.org/abs/1707.07012
+[3]: https://arxiv.org/abs/1806.09055
+[4]: https://arxiv.org/abs/1806.10282
+[5]: https://arxiv.org/abs/1703.01041
\ No newline at end of file
diff --git a/examples/trials/ga_squad/trial.py b/examples/trials/ga_squad/trial.py
index cb6640ac7a..815e88af4e 100644
--- a/examples/trials/ga_squad/trial.py
+++ b/examples/trials/ga_squad/trial.py
@@ -338,7 +338,7 @@ def train_with_graph(graph, qp_pairs, dev_qp_pairs):
                 answers = generate_predict_json(
                     position1, position2, ids, contexts)
                 if save_path is not None:
-                    with open(save_path + 'epoch%d.prediction' % epoch, 'w') as file:
+                    with open(os.path.join(save_path, 'epoch%d.prediction' % epoch), 'w') as file:
                         json.dump(answers, file)
                 else:
                     answers = json.dumps(answers)
@@ -359,8 +359,8 @@ def train_with_graph(graph, qp_pairs, dev_qp_pairs):
                         bestacc = acc
                         if save_path is not None:
-                            saver.save(sess, save_path + 'epoch%d.model' % epoch)
-                            with open(save_path + 'epoch%d.score' % epoch, 'wb') as file:
+                            saver.save(sess, os.path.join(save_path, 'epoch%d.model' % epoch))
+                            with open(os.path.join(save_path, 'epoch%d.score' % epoch), 'wb') as file:
                                 pickle.dump(
                                     (position1, position2, ids, contexts), file)
             logger.debug('epoch %d acc %g bestacc %g' %
diff --git a/examples/trials/weight_sharing/ga_squad/attention.py b/examples/trials/weight_sharing/ga_squad/attention.py
new file mode 100644
index 0000000000..812db53221
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/attention.py
@@ -0,0 +1,171 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
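+
+'''
+Dot-attention module used by the GA-SQuAD weight-sharing example.
+'''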
+
+import math
+
+import tensorflow as tf
+from tensorflow.python.ops.rnn_cell_impl import RNNCell
+
+
+def _get_variable(variable_dict, name, shape, initializer=None, dtype=tf.float32):
+    if name not in variable_dict:
+        variable_dict[name] = tf.get_variable(
+            name=name, shape=shape, initializer=initializer, dtype=dtype)
+    return variable_dict[name]
+
+
+class DotAttention:
+    '''
+    DotAttention
+    '''
+
+    def __init__(self, name,
+                 hidden_dim,
+                 is_vanilla=True,
+                 is_identity_transform=False,
+                 need_padding=False):
+        self._name = '/'.join([name, 'dot_att'])
+        self._hidden_dim = hidden_dim
+        self._is_identity_transform = is_identity_transform
+        self._need_padding = need_padding
+        self._is_vanilla = is_vanilla
+        self._var = {}
+
+    @property
+    def is_identity_transform(self):
+        return self._is_identity_transform
+
+    @property
+    def is_vanilla(self):
+        return self._is_vanilla
+
+    @property
+    def need_padding(self):
+        return self._need_padding
+
+    @property
+    def hidden_dim(self):
+        return self._hidden_dim
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def var(self):
+        return self._var
+
+    def _get_var(self, name, shape, initializer=None):
+        with tf.variable_scope(self.name):
+            return _get_variable(self.var, name, shape, initializer)
+
+    def _define_params(self, src_dim, tgt_dim):
+        hidden_dim = self.hidden_dim
+        self._get_var('W', [src_dim, hidden_dim])
+        if not self.is_vanilla:
+            self._get_var('V', [src_dim, hidden_dim])
+            if self.need_padding:
+                self._get_var('V_s', [src_dim, src_dim])
+                self._get_var('V_t', [tgt_dim, tgt_dim])
+            if not self.is_identity_transform:
+                self._get_var('T', [tgt_dim, src_dim])
+        self._get_var('U', [tgt_dim, hidden_dim])
+        self._get_var('b', [1, hidden_dim])
+        self._get_var('v', [hidden_dim, 1])
+
+    def get_pre_compute(self, s):
+        '''
+        :param s: [src_sequence_length, batch_size, src_dim]
+        :return: [src_sequence_length, batch_size, hidden_dim]
+        '''
+        hidden_dim = self.hidden_dim
+        src_dim = s.get_shape().as_list()[-1]
+        assert src_dim is not None, 'src dim must be defined'
+        W = self._get_var('W', shape=[src_dim, hidden_dim])
+        b = self._get_var('b', shape=[1, hidden_dim])
+        return tf.tensordot(s, W, [[2], [0]]) + b
+
+    def get_prob(self, src, tgt, mask, pre_compute, return_logits=False):
+        '''
+        :param s: [src_sequence_length, batch_size, src_dim]
+        :param h: [batch_size, tgt_dim] or [tgt_sequence_length, batch_size, tgt_dim]
+        :param mask: [src_sequence_length, batch_size]\
+            or [tgt_sequence_length, src_sequence_length, batch_size]
+        :param pre_compute: [src_sequence_length, batch_size, hidden_dim]
+        :return: [src_sequence_length, batch_size]\
+            or [tgt_sequence_length, src_sequence_length, batch_size]
+        '''
+        s_shape = src.get_shape().as_list()
+        h_shape = tgt.get_shape().as_list()
+        src_dim = s_shape[-1]
+        tgt_dim = h_shape[-1]
+        assert src_dim is not None, 'src dimension must be defined'
+        assert tgt_dim is not None, 'tgt dimension must be defined'
+
+        self._define_params(src_dim, tgt_dim)
+
+        if len(h_shape) == 2:
+            tgt = tf.expand_dims(tgt, 0)
+        if pre_compute is None:
+            pre_compute = self.get_pre_compute(src)
+
+        buf0 = pre_compute
+        buf1 = tf.tensordot(tgt, self.var['U'], axes=[[2], [0]])
+        buf2 = tf.tanh(tf.expand_dims(buf0, 0) + tf.expand_dims(buf1, 1))
+
+        if not self.is_vanilla:
+            xh1 = tgt
+            xh2 = tgt
+            s1 = src
+            if self.need_padding:
+                xh1 = tf.tensordot(xh1, self.var['V_t'], 1)
+                # 'V_t' is the defined target-side padding transform;
+                # an undefined 'S_t' here would raise a KeyError
+                xh2 = tf.tensordot(xh2, self.var['V_t'], 1)
+                s1 = tf.tensordot(s1, self.var['V_s'], 1)
+            if not self.is_identity_transform:
+                xh1 = tf.tensordot(xh1, self.var['T'], 1)
+                xh2 = tf.tensordot(xh2, self.var['T'], 1)
+            buf3 = tf.expand_dims(s1, 0) * tf.expand_dims(xh1, 1)
+            buf3 = tf.tanh(tf.tensordot(buf3, self.var['V'], axes=[[3], [0]]))
+            buf = tf.reshape(tf.tanh(buf2 + buf3), shape=tf.shape(buf3))
+        else:
+            buf = buf2
+        v = self.var['v']
+        e = tf.tensordot(buf, v, [[3], [0]])
+        e = tf.squeeze(e, axis=[3])
+        tmp = tf.reshape(e + (mask - 1) * 10000.0, shape=tf.shape(e))
+        prob = tf.nn.softmax(tmp, 1)
+        if len(h_shape) == 2:
+            prob = tf.squeeze(prob, axis=[0])
+            tmp = tf.squeeze(tmp, axis=[0])
+        if return_logits:
+            return prob, tmp
+        return prob
+
+    def get_att(self, s, prob):
+        '''
+        :param s: [src_sequence_length, batch_size, src_dim]
+        :param prob: [src_sequence_length, batch_size]\
+            or [tgt_sequence_length, src_sequence_length, batch_size]
+        :return: [batch_size, src_dim] or [tgt_sequence_length, batch_size, src_dim]
+        '''
+        buf = s * tf.expand_dims(prob, axis=-1)
+        att = tf.reduce_sum(buf, axis=-3)
+        return att
diff --git a/examples/trials/weight_sharing/ga_squad/config_remote.yml b/examples/trials/weight_sharing/ga_squad/config_remote.yml
new file mode 100644
index 0000000000..a07ab055cb
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/config_remote.yml
@@ -0,0 +1,31 @@
+authorName: default
+experimentName: ga_squad_weight_sharing
+trialConcurrency: 2
+maxExecDuration: 1h
+maxTrialNum: 200
+#choice: local, remote, pai
+trainingServicePlatform: remote
+#choice: true, false
+useAnnotation: false
+multiThread: true
+tuner:
+  codeDir: ../../../tuners/weight_sharing/ga_customer_tuner
+  classFileName: customer_tuner.py
+  className: CustomerTuner
+  classArgs:
+    optimize_mode: maximize
+    population_size: 32
+    save_dir_root: /mnt/nfs/nni/ga_squad
+trial:
+  command: python3 trial.py --input_file /mnt/nfs/nni/train-v1.1.json --dev_file /mnt/nfs/nni/dev-v1.1.json --max_epoch 1 --embedding_file /mnt/nfs/nni/glove.6B.300d.txt
+  codeDir: .
+  gpuNum: 1
+machineList:
+  - ip: remote-ip-0
+    port: 8022
+    username: root
+    passwd: screencast
+  - ip: remote-ip-1
+    port: 8022
+    username: root
+    passwd: screencast
diff --git a/examples/trials/weight_sharing/ga_squad/data.py b/examples/trials/weight_sharing/ga_squad/data.py
new file mode 100644
index 0000000000..074b5a5b28
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/data.py
@@ -0,0 +1,269 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+'''
+Data processing script for the QA model.
+'''
+
+import csv
+import json
+from random import shuffle
+
+import numpy as np
+
+
+class WhitespaceTokenizer:
+    '''
+    Tokenizer for whitespace
+    '''
+
+    def tokenize(self, text):
+        '''
+        tokenize function in Tokenizer.
+        '''
+        start = -1
+        tokens = []
+        for i, character in enumerate(text):
+            if character == ' ' or character == '\t':
+                if start >= 0:
+                    word = text[start:i]
+                    tokens.append({
+                        'word': word,
+                        'original_text': word,
+                        'char_begin': start,
+                        'char_end': i})
+                    start = -1
+            else:
+                if start < 0:
+                    start = i
+        if start >= 0:
+            tokens.append({
+                'word': text[start:len(text)],
+                'original_text': text[start:len(text)],
+                'char_begin': start,
+                'char_end': len(text)
+            })
+        return tokens
+
+
+def load_from_file(path, fmt=None, is_training=True):
+    '''
+    load data from file
+    '''
+    if fmt is None:
+        fmt = 'squad'
+    assert fmt in ['squad', 'csv'], 'input format must be squad or csv'
+    qp_pairs = []
+    if fmt == 'squad':
+        with open(path) as data_file:
+            data = json.load(data_file)['data']
+            for doc in data:
+                for paragraph in doc['paragraphs']:
+                    passage = paragraph['context']
+                    for qa_pair in paragraph['qas']:
+                        question = qa_pair['question']
+                        qa_id = qa_pair['id']
+                        if not is_training:
+                            qp_pairs.append(
+                                {'passage': passage, 'question': question, 'id': qa_id})
+                        else:
+                            for answer in qa_pair['answers']:
+                                answer_begin = int(answer['answer_start'])
+                                answer_end = answer_begin + len(answer['text'])
+                                qp_pairs.append({'passage': passage,
+                                                 'question': question,
+                                                 'id': qa_id,
+                                                 'answer_begin': answer_begin,
+                                                 'answer_end': answer_end})
+    else:
+        with open(path, newline='') as csvfile:
+            reader = csv.reader(csvfile, delimiter='\t')
+            line_num = 0
+            for row in reader:
+                qp_pairs.append(
+                    {'passage': row[1], 'question': row[0], 'id': line_num})
+                line_num += 1
+    return qp_pairs
+
+
+def tokenize(qp_pair, tokenizer=None, is_training=False):
+    '''
+    tokenize function.
+    '''
+    question_tokens = tokenizer.tokenize(qp_pair['question'])
+    passage_tokens = tokenizer.tokenize(qp_pair['passage'])
+    if is_training:
+        question_tokens = question_tokens[:300]
+        passage_tokens = passage_tokens[:300]
+    passage_tokens.insert(
+        0, {'word': '', 'original_text': '', 'char_begin': 0, 'char_end': 0})
+    passage_tokens.append(
+        {'word': '', 'original_text': '', 'char_begin': 0, 'char_end': 0})
+    qp_pair['question_tokens'] = question_tokens
+    qp_pair['passage_tokens'] = passage_tokens
+
+
+def collect_vocab(qp_pairs):
+    '''
+    Build the vocab from corpus.
+    '''
+    vocab = set()
+    for qp_pair in qp_pairs:
+        for word in qp_pair['question_tokens']:
+            vocab.add(word['word'])
+        for word in qp_pair['passage_tokens']:
+            vocab.add(word['word'])
+    return vocab
+
+
+def shuffle_step(entries, step):
+    '''
+    Shuffle the entries in chunks of the given step size.
+    '''
+    answer = []
+    for i in range(0, len(entries), step):
+        sub = entries[i:i+step]
+        shuffle(sub)
+        answer += sub
+    return answer
+
+
+def get_batches(qp_pairs, batch_size, need_sort=True):
+    '''
+    Get batches data and shuffle.
+    '''
+    if need_sort:
+        qp_pairs = sorted(qp_pairs, key=lambda qp: (
+            len(qp['passage_tokens']), qp['id']), reverse=True)
+    batches = [{'qp_pairs': qp_pairs[i:(i + batch_size)]}
+               for i in range(0, len(qp_pairs), batch_size)]
+    shuffle(batches)
+    return batches
+
+
+def get_char_input(data, char_dict, max_char_length):
+    '''
+    Get char input.
+    '''
+    batch_size = len(data)
+    sequence_length = max(len(d) for d in data)
+    char_id = np.zeros((max_char_length, sequence_length,
+                        batch_size), dtype=np.int32)
+    char_lengths = np.zeros((sequence_length, batch_size), dtype=np.float32)
+    for batch_idx in range(0, min(len(data), batch_size)):
+        batch_data = data[batch_idx]
+        for sample_idx in range(0, min(len(batch_data), sequence_length)):
+            word = batch_data[sample_idx]['word']
+            char_lengths[sample_idx, batch_idx] = min(
+                len(word), max_char_length)
+            for i in range(0, min(len(word), max_char_length)):
+                char_id[i, sample_idx, batch_idx] = get_id(char_dict, word[i])
+    return char_id, char_lengths
+
+
+def get_word_input(data, word_dict, embed, embed_dim):
+    '''
+    Get word input.
+    '''
+    batch_size = len(data)
+    max_sequence_length = max(len(d) for d in data)
+    sequence_length = max_sequence_length
+    word_input = np.zeros((max_sequence_length, batch_size,
+                           embed_dim), dtype=np.float32)
+    ids = np.zeros((sequence_length, batch_size), dtype=np.int32)
+    masks = np.zeros((sequence_length, batch_size), dtype=np.float32)
+    lengths = np.zeros([batch_size], dtype=np.int32)
+
+    for batch_idx in range(0, min(len(data), batch_size)):
+        batch_data = data[batch_idx]
+
+        lengths[batch_idx] = len(batch_data)
+
+        for sample_idx in range(0, min(len(batch_data), sequence_length)):
+            word = batch_data[sample_idx]['word'].lower()
+            if word in word_dict.keys():
+                word_input[sample_idx, batch_idx] = embed[word_dict[word]]
+                ids[sample_idx, batch_idx] = word_dict[word]
+                masks[sample_idx, batch_idx] = 1
+
+    word_input = np.reshape(word_input, (-1, embed_dim))
+    return word_input, ids, masks, lengths
+
+
+def get_word_index(tokens, char_index):
+    '''
+    Given a character index, return the index of the word containing it.
+    '''
+    for (i, token) in enumerate(tokens):
+        if token['char_end'] == 0:
+            continue
+        if token['char_begin'] <= char_index and char_index <= token['char_end']:
+            return i
+    return 0
+
+
+def get_answer_begin_end(data):
+    '''
+    Get answer's index of begin and end.
+    '''
+    begin = []
+    end = []
+    for qa_pair in data:
+        tokens = qa_pair['passage_tokens']
+        char_begin = qa_pair['answer_begin']
+        char_end = qa_pair['answer_end']
+        word_begin = get_word_index(tokens, char_begin)
+        word_end = get_word_index(tokens, char_end)
+        begin.append(word_begin)
+        end.append(word_end)
+    return np.asarray(begin), np.asarray(end)
+
+
+def get_id(word_dict, word):
+    '''
+    Given word, return word id.
+    '''
+    if word in word_dict.keys():
+        return word_dict[word]
+    return word_dict['']
+
+
+def get_buckets(min_length, max_length, bucket_count):
+    '''
+    Get bucket by length.
+    '''
+    if bucket_count <= 0:
+        return [max_length]
+    unit_length = int((max_length - min_length) // (bucket_count))
+    buckets = [min_length + unit_length *
+               (i + 1) for i in range(0, bucket_count)]
+    buckets[-1] = max_length
+    return buckets
+
+
+def find_bucket(length, buckets):
+    '''
+    Find bucket.
+    '''
+    for bucket in buckets:
+        if length <= bucket:
+            return bucket
+    return buckets[-1]
diff --git a/examples/trials/weight_sharing/ga_squad/download.sh b/examples/trials/weight_sharing/ga_squad/download.sh
new file mode 100644
index 0000000000..308fbaedbf
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/download.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
+wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
+wget http://nlp.stanford.edu/data/glove.840B.300d.zip
+unzip glove.840B.300d.zip
\ No newline at end of file
diff --git a/examples/trials/weight_sharing/ga_squad/evaluate.py b/examples/trials/weight_sharing/ga_squad/evaluate.py
new file mode 100644
index 0000000000..d2bc208cf4
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/evaluate.py
@@ -0,0 +1,169 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+'''
+Evaluation scripts for QA model.
+'''
+
+from __future__ import print_function
+from collections import Counter
+import string
+import re
+import argparse
+import json
+import sys
+
+
+def normalize_answer(str_input):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+    def remove_articles(text):
+        '''
+        Remove "a|an|the"
+        '''
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        '''
+        Remove unnecessary whitespace.
+        '''
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        '''
+        Remove punctuation.
+        '''
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        '''
+        Change string to lower form.
+        '''
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(str_input))))
+
+
+def f1_score(prediction, ground_truth):
+    '''
+    Calculate the f1 score.
+    '''
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1_result = (2 * precision * recall) / (precision + recall)
+    return f1_result
+
+
+def exact_match_score(prediction, ground_truth):
+    '''
+    Calculate the match score with prediction and ground truth.
+    '''
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+    '''
+    Metric max over the ground truths.
+    '''
+    scores_for_ground_truths = []
+    for ground_truth in ground_truths:
+        score = metric_fn(prediction, ground_truth)
+        scores_for_ground_truths.append(score)
+    return max(scores_for_ground_truths)
+
+
+def _evaluate(dataset, predictions):
+    '''
+    Evaluate function.
+    '''
+    f1_result = exact_match = total = 0
+    count = 0
+    for article in dataset:
+        for paragraph in article['paragraphs']:
+            for qa_pair in paragraph['qas']:
+                total += 1
+                if qa_pair['id'] not in predictions:
+                    count += 1
+                    continue
+                ground_truths = list(
+                    map(lambda x: x['text'], qa_pair['answers']))
+                prediction = predictions[qa_pair['id']]
+                exact_match += metric_max_over_ground_truths(
+                    exact_match_score, prediction, ground_truths)
+                f1_result += metric_max_over_ground_truths(
+                    f1_score, prediction, ground_truths)
+    print('total', total, 'exact_match',
+          exact_match, 'unanswered_question ', count)
+    exact_match = 100.0 * exact_match / total
+    f1_result = 100.0 * f1_result / total
+    return {'exact_match': exact_match, 'f1': f1_result}
+
+
+def evaluate(data_file, pred_file):
+    '''
+    Evaluate.
+    '''
+    expected_version = '1.1'
+    with open(data_file) as dataset_file:
+        dataset_json = json.load(dataset_file)
+        if dataset_json['version'] != expected_version:
+            print('Evaluation expects v-' + expected_version +
+                  ', but got dataset with v-' + dataset_json['version'],
+                  file=sys.stderr)
+        dataset = dataset_json['data']
+    with open(pred_file) as prediction_file:
+        predictions = json.load(prediction_file)
+    # print(json.dumps(evaluate(dataset, predictions)))
+    result = _evaluate(dataset, predictions)
+    # print('em:', result['exact_match'], 'f1:', result['f1'])
+    return result['exact_match']
+
+
+def evaluate_with_predictions(data_file, predictions):
+    '''
+    Evaluate with predictions.
+    '''
+    expected_version = '1.1'
+    with open(data_file) as dataset_file:
+        dataset_json = json.load(dataset_file)
+        if dataset_json['version'] != expected_version:
+            print('Evaluation expects v-' + expected_version +
+                  ', but got dataset with v-' + dataset_json['version'],
+                  file=sys.stderr)
+        dataset = dataset_json['data']
+    result = _evaluate(dataset, predictions)
+    return result['exact_match']
+
+
+if __name__ == '__main__':
+    EXPECT_VERSION = '1.1'
+    parser = argparse.ArgumentParser(
+        description='Evaluation for SQuAD ' + EXPECT_VERSION)
+    parser.add_argument('dataset_file', help='Dataset file')
+    parser.add_argument('prediction_file', help='Prediction File')
+    args = parser.parse_args()
+    print(evaluate(args.dataset_file, args.prediction_file))
diff --git a/examples/trials/weight_sharing/ga_squad/graph.py b/examples/trials/weight_sharing/ga_squad/graph.py
new file mode 100644
index 0000000000..8e675a06ff
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/graph.py
@@ -0,0 +1,336 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+Graph is a custom-defined class; this module contains related classes and functions about graphs.
+'''
+
+
+import copy
+import hashlib
+import logging
+import json
+import random
+from collections import deque
+from enum import Enum, unique
+from typing import Iterable
+
+import numpy as np
+
+_logger = logging.getLogger('ga_squad_graph')
+
+@unique
+class LayerType(Enum):
+    '''
+    Layer type
+    '''
+    attention = 0
+    self_attention = 1
+    rnn = 2
+    input = 3
+    output = 4
+
+class Layer(object):
+    '''
+    Layer class, which contains the information of a graph layer.
+    '''
+    def __init__(self, graph_type, inputs=None, output=None, size=None, hash_id=None):
+        self.input = inputs if inputs is not None else []
+        self.output = output if output is not None else []
+        self.graph_type = graph_type
+        self.is_delete = False
+        self.size = size
+        self.hash_id = hash_id
+        if graph_type == LayerType.attention.value:
+            self.input_size = 2
+            self.output_size = 1
+        elif graph_type == LayerType.rnn.value:
+            self.input_size = 1
+            self.output_size = 1
+        elif graph_type == LayerType.self_attention.value:
+            self.input_size = 1
+            self.output_size = 1
+        elif graph_type == LayerType.input.value:
+            self.input_size = 0
+            self.output_size = 1
+            if self.hash_id is None:
+                hasher = hashlib.md5()
+                hasher.update(np.random.bytes(100))
+                self.hash_id = hasher.hexdigest()
+        elif graph_type == LayerType.output.value:
+            self.input_size = 1
+            self.output_size = 0
+        else:
+            raise ValueError('Unsupported LayerType: {}'.format(graph_type))
+
+    def update_hash(self, layers: Iterable):
+        """
+        Compute the `hash_id` of this layer, which is determined by its own
+        properties and the `hash_id`s of its input layers.
+        """
+        if self.graph_type == LayerType.input.value:
+            return
+        hasher = hashlib.md5()
+        hasher.update(LayerType(self.graph_type).name.encode('ascii'))
+        hasher.update(str(self.size).encode('ascii'))
+        for i in self.input:
+            if layers[i].hash_id is None:
+                raise ValueError('Hash id of layer {}: {} not generated!'.format(i, layers[i]))
+            hasher.update(layers[i].hash_id.encode('ascii'))
+        self.hash_id = hasher.hexdigest()
+
+    def set_size(self, graph_id, size):
+        '''
+        Set size.
+        '''
+        if self.graph_type == LayerType.attention.value:
+            if self.input[0] == graph_id:
+                self.size = size
+        if self.graph_type == LayerType.rnn.value:
+            self.size = size
+        if self.graph_type == LayerType.self_attention.value:
+            self.size = size
+        if self.graph_type == LayerType.output.value:
+            if self.size != size:
+                return False
+        return True
+
+    def clear_size(self):
+        '''
+        Clear size
+        '''
+        # test membership in the intended set of layer types; a plain
+        # `== ... or LayerType.rnn.value or ...` chain is always truthy
+        if self.graph_type in (LayerType.attention.value,
+                               LayerType.rnn.value,
+                               LayerType.self_attention.value):
+            self.size = None
+
+    def __str__(self):
+        return 'input:' + str(self.input) + ' output:' + str(self.output) + ' type:' + str(self.graph_type) + ' is_delete:' + str(self.is_delete) + ' size:' + str(self.size)
+
+def graph_dumps(graph):
+    '''
+    Dump the graph.
+    '''
+    return json.dumps(graph, default=lambda obj: obj.__dict__)
+
+def graph_loads(graph_json):
+    '''
+    Load graph
+    '''
+    layers = []
+    for layer in graph_json['layers']:
+        layer_info = Layer(layer['graph_type'], layer['input'], layer['output'], layer['size'], layer['hash_id'])
+        layer_info.is_delete = layer['is_delete']
+        _logger.debug('append layer {}'.format(layer_info))
+        layers.append(layer_info)
+    graph = Graph(graph_json['max_layer_num'], graph_json['min_layer_num'], [], [], [])
+    graph.layers = layers
+    _logger.debug('graph {} loaded'.format(graph))
+    return graph
+
+class Graph(object):
+    '''
+    Custom Graph class.
+    '''
+    def __init__(self, max_layer_num, min_layer_num, inputs, output, hide):
+        self.layers = []
+        self.max_layer_num = max_layer_num
+        self.min_layer_num = min_layer_num
+        assert min_layer_num < max_layer_num
+
+        for layer in inputs:
+            self.layers.append(layer)
+        for layer in output:
+            self.layers.append(layer)
+        if hide is not None:
+            for layer in hide:
+                self.layers.append(layer)
+        assert self.is_legal()
+
+    def is_topology(self, layers=None):
+        '''
+        Validate the topology.
+        '''
+        if layers is None:
+            layers = self.layers
+        layers_nodle = []
+        result = []
+        for i, layer in enumerate(layers):
+            if layer.is_delete is False:
+                layers_nodle.append(i)
+        while True:
+            flag_break = True
+            layers_toremove = []
+            for layer1 in layers_nodle:
+                flag_arrive = True
+                for layer2 in layers[layer1].input:
+                    if layer2 in layers_nodle:
+                        flag_arrive = False
+                if flag_arrive is True:
+                    for layer2 in layers[layer1].output:
+                        # Size is error
+                        if layers[layer2].set_size(layer1, layers[layer1].size) is False:
+                            return False
+                    layers_toremove.append(layer1)
+                    result.append(layer1)
+                    flag_break = False
+            for layer in layers_toremove:
+                layers_nodle.remove(layer)
+            result.append('|')
+            if flag_break:
+                break
+        # There is a loop in the graph || some layers can't be reached
+        if layers_nodle:
+            return False
+        return result
+
+    def layer_num(self, layers=None):
+        '''
+        Return the number of layers.
+        '''
+        if layers is None:
+            layers = self.layers
+        layer_num = 0
+        for layer in layers:
+            if layer.is_delete is False and layer.graph_type != LayerType.input.value\
+                    and layer.graph_type != LayerType.output.value:
+                layer_num += 1
+        return layer_num
+
+    def is_legal(self, layers=None):
+        '''
+        Judge whether the layers are legal.
+        '''
+        if layers is None:
+            layers = self.layers
+
+        for layer in layers:
+            if layer.is_delete is False:
+                if len(layer.input) != layer.input_size:
+                    return False
+                if len(layer.output) < layer.output_size:
+                    return False
+
+        # layer_num <= max_layer_num
+        if self.layer_num(layers) > self.max_layer_num:
+            return False
+
+        # There is a loop in the graph || some layers can't be reached
+        if self.is_topology(layers) is False:
+            return False
+
+        return True
+
+    def update_hash(self):
+        """
+        update hash id of each layer, in topological order/recursively
+        hash id will be used in weight sharing
+        """
+        _logger.debug('update hash')
+        layer_in_cnt = [len(layer.input) for layer in self.layers]
+        topo_queue = deque([i for i, layer in enumerate(self.layers)
+                            if not layer.is_delete and layer.graph_type == LayerType.input.value])
+        while topo_queue:
+            layer_i = topo_queue.pop()
+            self.layers[layer_i].update_hash(self.layers)
+            for layer_j in self.layers[layer_i].output:
+                layer_in_cnt[layer_j] -= 1
+                if layer_in_cnt[layer_j] == 0:
+                    topo_queue.appendleft(layer_j)
+
+    def mutation(self, only_add=False):
+        '''
+        Mutation for a graph
+        '''
+        types = []
+        if self.layer_num() < self.max_layer_num:
+            types.append(0)
+            types.append(1)
+        if self.layer_num() > self.min_layer_num and only_add is False:
+            types.append(2)
+            types.append(3)
+        # 0 : add a layer , delete a edge
+        # 1 : add a layer , change a edge
+        # 2 : delete a layer, delete a edge
+        # 3 : delete a layer, change a edge
+        graph_type = random.choice(types)
+        layer_type = random.choice([LayerType.attention.value,\
+            LayerType.self_attention.value, LayerType.rnn.value])
+        layers = copy.deepcopy(self.layers)
+        cnt_try = 0
+        while True:
+            layers_in = []
+            layers_out = []
+            layers_del = []
+            for i, layer in enumerate(layers):
+                if layer.is_delete is False:
+                    if layer.graph_type != LayerType.output.value:
+                        layers_in.append(i)
+                    if layer.graph_type != LayerType.input.value:
+                        layers_out.append(i)
+                    if layer.graph_type != LayerType.output.value\
+                            and layer.graph_type != LayerType.input.value:
+                        layers_del.append(i)
+            if graph_type <= 1:
+                new_id = len(layers)
+                out = random.choice(layers_out)
+                inputs = []
+                output = [out]
+                pos = random.randint(0, len(layers[out].input) - 1)
+                last_in = layers[out].input[pos]
+                layers[out].input[pos] = new_id
+                if graph_type == 0:
+                    layers[last_in].output.remove(out)
+                if graph_type == 1:
+                    layers[last_in].output.remove(out)
+                    layers[last_in].output.append(new_id)
+                    inputs = [last_in]
+                lay = Layer(graph_type=layer_type, inputs=inputs, output=output)
+                while len(inputs) < lay.input_size:
+                    layer1 = random.choice(layers_in)
+                    inputs.append(layer1)
+                    layers[layer1].output.append(new_id)
+                lay.input = inputs
+                layers.append(lay)
+            else:
+                layer1 = random.choice(layers_del)
+                for layer2 in layers[layer1].output:
+                    layers[layer2].input.remove(layer1)
+                    if graph_type == 2:
+                        random_in = random.choice(layers_in)
+                    else:
+                        random_in = random.choice(layers[layer1].input)
+                    layers[layer2].input.append(random_in)
+                    layers[random_in].output.append(layer2)
+                for layer2 in layers[layer1].input:
+                    layers[layer2].output.remove(layer1)
+                layers[layer1].is_delete = True
+
+            if self.is_legal(layers):
+                self.layers = layers
+                break
+            else:
+                layers = copy.deepcopy(self.layers)
+                cnt_try += 1
+        self.update_hash()
+
+    def __str__(self):
+        info = ""
+        for l_id, layer in enumerate(self.layers):
+            if layer.is_delete is False:
+                info += 'id:%d ' % l_id + str(layer) + '\n'
+        return info
diff --git a/examples/trials/weight_sharing/ga_squad/graph_to_tf.py b/examples/trials/weight_sharing/ga_squad/graph_to_tf.py
new file mode 100644
index 0000000000..2712d531ca
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/graph_to_tf.py
@@ -0,0 +1,342 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import tensorflow as tf
+from rnn import XGRUCell
+from util import dropout
+from graph import LayerType
+
+
+def normalize(inputs,
+              epsilon=1e-8,
+              scope="ln"):
+    '''Applies layer normalization.
+
+    Args:
+      inputs: A tensor with 2 or more dimensions, where the first dimension has
+        `batch_size`.
+      epsilon: A floating number. A very small number for preventing ZeroDivision Error.
+      scope: Optional scope for `variable_scope`.
+      reuse: Boolean, whether to reuse the weights of a previous layer
+        by the same name.
+
+    Returns:
+      A tensor with the same shape and data dtype as `inputs`.
+    '''
+    with tf.variable_scope(scope):
+        inputs_shape = inputs.get_shape()
+        params_shape = inputs_shape[-1:]
+
+        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
+        beta = tf.Variable(tf.zeros(params_shape))
+        gamma = tf.Variable(tf.ones(params_shape))
+        normalized = (inputs - mean) / ((variance + epsilon) ** (.5))
+        outputs = gamma * normalized + beta
+
+    return outputs
+
+
+def multihead_attention(queries,
+                        keys,
+                        scope="multihead_attention",
+                        num_units=None,
+                        num_heads=4,
+                        dropout_rate=0,
+                        is_training=True,
+                        causality=False):
+    '''Applies multihead attention.
+
+    Args:
+      queries: A 3d tensor with shape of [N, T_q, C_q].
+      keys: A 3d tensor with shape of [N, T_k, C_k].
+      num_units: A scalar. Attention size.
+      dropout_rate: A floating point number.
+      is_training: Boolean. Controller of mechanism for dropout.
+      causality: Boolean. If true, units that reference the future are masked.
+      num_heads: An int. Number of heads.
+      scope: Optional scope for `variable_scope`.
+      reuse: Boolean, whether to reuse the weights of a previous layer
+        by the same name.
+
+    Returns:
+      A 3d tensor with shape of (N, T_q, C)
+    '''
+    global look5
+    with tf.variable_scope(scope):
+        # Set the fall back option for num_units
+        if num_units is None:
+            num_units = queries.get_shape().as_list()[-1]
+
+        Q_ = []
+        K_ = []
+        V_ = []
+        for head_i in range(num_heads):
+            Q = tf.layers.dense(queries, num_units / num_heads,
+                                activation=tf.nn.relu, name='Query' + str(head_i))  # (N, T_q, C)
+            K = tf.layers.dense(keys, num_units / num_heads,
+                                activation=tf.nn.relu, name='Key' + str(head_i))  # (N, T_k, C)
+            V = tf.layers.dense(keys, num_units / num_heads,
+                                activation=tf.nn.relu, name='Value' + str(head_i))  # (N, T_k, C)
+            Q_.append(Q)
+            K_.append(K)
+            V_.append(V)
+
+        # Split and concat
+        Q_ = tf.concat(Q_, axis=0)  # (h*N, T_q, C/h)
+        K_ = tf.concat(K_, axis=0)  # (h*N, T_k, C/h)
+        V_ = tf.concat(V_, axis=0)  # (h*N, T_k, C/h)
+
+        # Multiplication
+        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)
+
+        # Scale
+        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
+
+        # Key Masking
+        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
+        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
+        key_masks = tf.tile(tf.expand_dims(key_masks, 1),
+                            [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)
+
+        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
+        outputs = tf.where(tf.equal(key_masks, 0), paddings,
+                           outputs)  # (h*N, T_q, T_k)
+
+        # Causality = Future blinding
+        if causality:
+            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
+            tril = tf.contrib.linalg.LinearOperatorTriL(
+                diag_vals).to_dense()  # (T_q, T_k)
+            masks = tf.tile(tf.expand_dims(tril, 0),
+                            [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)
+
+            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
+            outputs = tf.where(tf.equal(masks, 0), paddings,
+                               outputs)  # (h*N, T_q, T_k)
+
+        # Activation
+        look5 = outputs
+        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)
+
+        # Query Masking
+        query_masks = tf.sign(
+            tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
+        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
+        query_masks = tf.tile(tf.expand_dims(
+            query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
+        outputs *= query_masks  # broadcasting. (N, T_q, C)
+
+        # Dropouts
+        outputs = dropout(outputs, dropout_rate, is_training)
+
+        # Weighted sum
+        outputs = tf.matmul(outputs, V_)  # ( h*N, T_q, C/h)
+
+        # Restore shape
+        outputs = tf.concat(tf.split(outputs, num_heads,
+                                     axis=0), axis=2)  # (N, T_q, C)
+
+        # Residual connection
+        if queries.get_shape().as_list()[-1] == num_units:
+            outputs += queries
+
+        # Normalize
+        outputs = normalize(outputs, scope=scope)  # (N, T_q, C)
+
+    return outputs
+
+
+def positional_encoding(inputs,
+                        num_units=None,
+                        zero_pad=True,
+                        scale=True,
+                        scope="positional_encoding",
+                        reuse=None):
+    '''
+    Return positional embedding.
+    '''
+    Shape = tf.shape(inputs)
+    N = Shape[0]
+    T = Shape[1]
+    num_units = Shape[2]
+    with tf.variable_scope(scope, reuse=reuse):
+        position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1])
+
+        # First part of the PE function: sin and cos argument
+        # Second part, apply the cosine to even columns and sin to odds.
+        X = tf.expand_dims(tf.cast(tf.range(T), tf.float32), axis=1)
+        Y = tf.expand_dims(
+            tf.cast(10000 ** -(2 * tf.range(num_units) / num_units), tf.float32), axis=0)
+        h1 = tf.cast((tf.range(num_units) + 1) % 2, tf.float32)
+        h2 = tf.cast((tf.range(num_units) % 2), tf.float32)
+        position_enc = tf.multiply(X, Y)
+        position_enc = tf.sin(position_enc) * tf.multiply(tf.ones_like(X), h1) + \
+            tf.cos(position_enc) * tf.multiply(tf.ones_like(X), h2)
+
+        # Convert to a tensor
+        lookup_table = position_enc
+
+        if zero_pad:
+            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
+                                      lookup_table[1:, :]), 0)
+        outputs = tf.nn.embedding_lookup(lookup_table, position_ind)
+
+        if scale:
+            outputs = outputs * tf.sqrt(tf.cast(num_units, tf.float32))
+
+    return outputs
+
+
+def feedforward(inputs,
+                num_units,
+                scope="multihead_attention"):
+    '''Point-wise feed forward net.
+
+    Args:
+      inputs: A 3d tensor with shape of [N, T, C].
+      num_units: A list of two integers.
+      scope: Optional scope for `variable_scope`.
+      reuse: Boolean, whether to reuse the weights of a previous layer
+        by the same name.
+
+    Returns:
+      A 3d tensor with the same shape and dtype as inputs
+    '''
+    with tf.variable_scope(scope):
+        # Inner layer
+        params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
+                  "activation": tf.nn.relu, "use_bias": True}
+        outputs = tf.layers.conv1d(**params)
+
+        # Readout layer
+        params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
+                  "activation": None, "use_bias": True}
+        outputs = tf.layers.conv1d(**params)
+
+        # Residual connection
+        outputs += inputs
+
+        # Normalize
+        outputs = normalize(outputs)
+
+    return outputs
+
+
+def rnn(input_states, sequence_lengths, dropout_rate, is_training, num_units):
+    layer_cnt = 1
+    states = []
+    xs = tf.transpose(input_states, perm=[1, 0, 2])
+    for i in range(0, layer_cnt):
+        xs = dropout(xs, dropout_rate, is_training)
+        with tf.variable_scope('layer_' + str(i)):
+            cell_fw = XGRUCell(num_units)
+            cell_bw = XGRUCell(num_units)
+            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
+                cell_fw=cell_fw,
+                cell_bw=cell_bw,
+                dtype=tf.float32,
+                sequence_length=sequence_lengths,
+                inputs=xs,
+                time_major=True)
+
+            y_lr, y_rl = outputs
+            xs = tf.concat([y_lr, y_rl], 2)
+            states.append(xs)
+
+    return tf.transpose(dropout(tf.concat(states, axis=2),
+                                dropout_rate,
+                                is_training), perm=[1, 0, 2])
+
+
+def graph_to_network(input1,
+                     input2,
+                     input1_lengths,
+                     input2_lengths,
+                     p_graph,
+                     dropout_rate,
+                     is_training,
+                     num_heads=1,
+                     rnn_units=256):
+    topology = p_graph.is_topology()
+    layers = dict()
+    layers_sequence_lengths = dict()
+    num_units = input1.get_shape().as_list()[-1]
+    layers[0] = input1*tf.sqrt(tf.cast(num_units, tf.float32)) + \
+        positional_encoding(input1, scale=False, zero_pad=False)
+    layers[1] = input2*tf.sqrt(tf.cast(num_units, tf.float32))
+    layers[0] = dropout(layers[0], dropout_rate, is_training)
+    layers[1] = dropout(layers[1], dropout_rate, is_training)
+    layers_sequence_lengths[0] = input1_lengths
+    layers_sequence_lengths[1] = input2_lengths
+    for _, topo_i in enumerate(topology):
+        if topo_i == '|':
+            continue
+
+        # Note: here we use the `hash_id` of layer as scope name,
+        # so that we can automatically load sharable weights from previous trained models
+        with tf.variable_scope(p_graph.layers[topo_i].hash_id, reuse=tf.AUTO_REUSE):
+            if p_graph.layers[topo_i].graph_type == LayerType.input.value:
+                continue
+            elif p_graph.layers[topo_i].graph_type == LayerType.attention.value:
+                with tf.variable_scope('attention'):
+                    layer = multihead_attention(layers[p_graph.layers[topo_i].input[0]],
+                                                layers[p_graph.layers[topo_i].input[1]],
+                                                scope="multihead_attention",
+                                                dropout_rate=dropout_rate,
+                                                is_training=is_training,
+                                                num_heads=num_heads,
+                                                num_units=rnn_units * 2)
+                    layer = feedforward(layer, scope="feedforward",
+                                        num_units=[rnn_units * 2 * 4, rnn_units * 2])
+                layers[topo_i] = layer
+                layers_sequence_lengths[topo_i] = layers_sequence_lengths[
+                    p_graph.layers[topo_i].input[0]]
+            elif p_graph.layers[topo_i].graph_type == LayerType.self_attention.value:
+                with tf.variable_scope('self-attention'):
+                    layer = multihead_attention(layers[p_graph.layers[topo_i].input[0]],
+                                                layers[p_graph.layers[topo_i].input[0]],
+                                                scope="multihead_attention",
+                                                dropout_rate=dropout_rate,
+                                                is_training=is_training,
+                                                num_heads=num_heads,
+                                                num_units=rnn_units * 2)
+                    layer = feedforward(layer, scope="feedforward",
+                                        num_units=[rnn_units * 2 * 4, rnn_units * 2])
+                layers[topo_i] = layer
+                layers_sequence_lengths[topo_i] = layers_sequence_lengths[
+                    p_graph.layers[topo_i].input[0]]
+            elif p_graph.layers[topo_i].graph_type == LayerType.rnn.value:
+                with tf.variable_scope('rnn'):
+                    layer = rnn(layers[p_graph.layers[topo_i].input[0]],
+                                layers_sequence_lengths[p_graph.layers[topo_i].input[0]],
+                                dropout_rate,
+                                is_training,
+                                rnn_units)
+                layers[topo_i] = layer
+                layers_sequence_lengths[topo_i] = layers_sequence_lengths[
+                    p_graph.layers[topo_i].input[0]]
+            elif p_graph.layers[topo_i].graph_type == LayerType.output.value:
+                layers[topo_i] = layers[p_graph.layers[topo_i].input[0]]
+                if layers[topo_i].get_shape().as_list()[-1] != rnn_units * 1 * 2:
+                    with tf.variable_scope('add_dense'):
+                        layers[topo_i] = tf.layers.dense(
+                            layers[topo_i], units=rnn_units*2)
+    return layers[2], layers[3]
diff --git a/examples/trials/weight_sharing/ga_squad/rnn.py b/examples/trials/weight_sharing/ga_squad/rnn.py
new file mode 100644
index 0000000000..82f7d070bf
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/rnn.py
@@ -0,0 +1,118 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import tensorflow as tf
+from tensorflow.python.ops.rnn_cell_impl import RNNCell
+
+
+class GRU:
+    '''
+    GRU class.
+    '''
+    def __init__(self, name, input_dim, hidden_dim):
+        self.name = '/'.join([name, 'gru'])
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.w_matrix = None
+        self.U = None
+        self.bias = None
+
+    def define_params(self):
+        '''
+        Define parameters.
+        '''
+        input_dim = self.input_dim
+        hidden_dim = self.hidden_dim
+        prefix = self.name
+        self.w_matrix = tf.Variable(tf.random_normal([input_dim, 3 * hidden_dim], stddev=0.1),
+                                    name='/'.join([prefix, 'W']))
+        self.U = tf.Variable(tf.random_normal([hidden_dim, 3 * hidden_dim], stddev=0.1),
+                             name='/'.join([prefix, 'U']))
+        self.bias = tf.Variable(tf.random_normal([1, 3 * hidden_dim], stddev=0.1),
+                                name='/'.join([prefix, 'b']))
+        return self
+
+    def build(self, x, h, mask=None):
+        '''
+        Build the GRU cell.
+        '''
+        xw = tf.split(tf.matmul(x, self.w_matrix) + self.bias, 3, 1)
+        hu = tf.split(tf.matmul(h, self.U), 3, 1)
+        r = tf.sigmoid(xw[0] + hu[0])
+        z = tf.sigmoid(xw[1] + hu[1])
+        h1 = tf.tanh(xw[2] + r * hu[2])
+        next_h = h1 * (1 - z) + h * z
+        if mask is not None:
+            next_h = next_h * mask + h * (1 - mask)
+        return next_h
+
+    def build_sequence(self, xs, masks, init, is_left_to_right):
+        '''
+        Build GRU sequence.
+        '''
+        states = []
+        last = init
+        if is_left_to_right:
+            for i, xs_i in enumerate(xs):
+                h = self.build(xs_i, last, masks[i])
+                states.append(h)
+                last = h
+        else:
+            for i in range(len(xs) - 1, -1, -1):
+                h = self.build(xs[i], last, masks[i])
+                states.insert(0, h)
+                last = h
+        return states
+
+
+class XGRUCell(RNNCell):
+    '''
+    GRU cell whose variables are created in `call` via `tf.get_variable`,
+    so that they live under the enclosing variable scope.
+    '''
+
+    def __init__(self, hidden_dim, reuse=None):
+        super(XGRUCell, self).__init__(_reuse=reuse)
+        self._num_units = hidden_dim
+        self._activation = tf.tanh
+
+    @property
+    def state_size(self):
+        return self._num_units
+
+    @property
+    def output_size(self):
+        return self._num_units
+
+    def call(self, inputs, state):
+
+        input_dim = inputs.get_shape()[-1]
+        assert input_dim is not None, "input dimension must be defined"
+        W = tf.get_variable(
+            name="W", shape=[input_dim, 3 * self._num_units], dtype=tf.float32)
+        U = tf.get_variable(
+            name='U', shape=[self._num_units, 3 * self._num_units], dtype=tf.float32)
+        b = tf.get_variable(
+            name='b', shape=[1, 3 * self._num_units], dtype=tf.float32)
+
+        xw = tf.split(tf.matmul(inputs, W) + b, 3, 1)
+        hu = tf.split(tf.matmul(state, U), 3, 1)
+        r = tf.sigmoid(xw[0] + hu[0])
+        z = tf.sigmoid(xw[1] + hu[1])
+        h1 = self._activation(xw[2] + r * hu[2])
+        next_h = h1 * (1 - z) + state * z
+        return next_h, next_h
diff --git a/examples/trials/weight_sharing/ga_squad/train_model.py b/examples/trials/weight_sharing/ga_squad/train_model.py
new file mode 100644
index 0000000000..b8240bc960
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/train_model.py
@@ -0,0 +1,263 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+'''
+Train the network combined by RNN and attention.
+'''
+
+import tensorflow as tf
+
+from attention import DotAttention
+from rnn import XGRUCell
+from util import dropout
+from graph_to_tf import graph_to_network
+
+
+class GAGConfig:
+    """The class for model hyper-parameter configuration."""
+    def __init__(self):
+        self.batch_size = 128
+
+        self.dropout = 0.1
+
+        self.char_vcb_size = 1500
+        self.max_char_length = 20
+        self.char_embed_dim = 100
+
+        self.max_query_length = 40
+        self.max_passage_length = 800
+
+        self.att_is_vanilla = True
+        self.att_need_padding = False
+        self.att_is_id = False
+
+        self.ptr_dim = 70
+        self.learning_rate = 0.1
+        self.labelsmoothing = 0.1
+        self.num_heads = 1
+        self.rnn_units = 256
+
+
+class GAG:
+    """The class for the computation graph based QA model."""
+    def __init__(self, cfg, embed, p_graph):
+        self.cfg = cfg
+        self.embed = embed
+        self.graph = p_graph
+
+        self.query_word = None
+        self.query_mask = None
+        self.query_lengths = None
+        self.passage_word = None
+        self.passage_mask = None
+        self.passage_lengths = None
+        self.answer_begin = None
+        self.answer_end = None
+        self.query_char_ids = None
+        self.query_char_lengths = None
+        self.passage_char_ids = None
+        self.passage_char_lengths = None
+        self.passage_states = None
+        self.query_states = None
+        self.query_init = None
+        self.begin_prob = None
+        self.end_prob = None
+        self.loss = None
+        self.train_op = None
+
+
+    def build_net(self, is_training):
+        """Build the whole neural network for the QA model."""
+        cfg = self.cfg
+        word_embed = tf.get_variable(
+            name='word_embed', initializer=self.embed, dtype=tf.float32, trainable=False)
+        char_embed = tf.get_variable(name='char_embed',
+                                     shape=[cfg.char_vcb_size,
+                                            cfg.char_embed_dim],
+                                     dtype=tf.float32)
+
+        # [query_length, batch_size]
+        self.query_word = tf.placeholder(dtype=tf.int32,
+                                         shape=[None, None],
+                                         name='query_word')
+        self.query_mask = tf.placeholder(dtype=tf.float32,
+                                         shape=[None, None],
+                                         name='query_mask')
+        # [batch_size]
+        self.query_lengths = tf.placeholder(
+            dtype=tf.int32, shape=[None], name='query_lengths')
+
+        # [passage_length, batch_size]
+        self.passage_word = tf.placeholder(
+            dtype=tf.int32, shape=[None, None], name='passage_word')
+        self.passage_mask = tf.placeholder(
+            dtype=tf.float32, shape=[None, None], name='passage_mask')
+        # [batch_size]
+        self.passage_lengths = tf.placeholder(
+            dtype=tf.int32, shape=[None], name='passage_lengths')
+
+        if is_training:
+            self.answer_begin = tf.placeholder(
+                dtype=tf.int32, shape=[None], name='answer_begin')
+            self.answer_end = tf.placeholder(
+                dtype=tf.int32, shape=[None], name='answer_end')
+
+        self.query_char_ids = tf.placeholder(dtype=tf.int32,
+                                             shape=[
+                                                 self.cfg.max_char_length, None, None],
+                                             name='query_char_ids')
+        # sequence_length, batch_size
+        self.query_char_lengths = tf.placeholder(
+            dtype=tf.int32, shape=[None, None], name='query_char_lengths')
+
+        self.passage_char_ids = tf.placeholder(dtype=tf.int32,
+                                               shape=[
+                                                   self.cfg.max_char_length, None, None],
+                                               name='passage_char_ids')
+        # sequence_length, batch_size
+        self.passage_char_lengths = tf.placeholder(dtype=tf.int32,
+                                                   shape=[None, None],
+                                                   name='passage_char_lengths')
+
+        query_char_states = self.build_char_states(char_embed=char_embed,
+                                                   is_training=is_training,
+                                                   reuse=False,
+                                                   char_ids=self.query_char_ids,
+                                                   char_lengths=self.query_char_lengths)
+
+        passage_char_states = self.build_char_states(char_embed=char_embed,
+                                                     is_training=is_training,
+                                                     reuse=True,
+                                                     char_ids=self.passage_char_ids,
+                                                     char_lengths=self.passage_char_lengths)
+
+        with tf.variable_scope("encoding") as scope:
+            query_states = tf.concat([tf.nn.embedding_lookup(
+                word_embed, self.query_word), query_char_states], axis=2)
+            scope.reuse_variables()
+            passage_states = tf.concat([tf.nn.embedding_lookup(
+                word_embed, self.passage_word), passage_char_states], axis=2)
+        passage_states = tf.transpose(passage_states, perm=[1, 0, 2])
+        query_states = tf.transpose(query_states, perm=[1, 0, 2])
+        self.passage_states = passage_states
+        self.query_states = query_states
+
+        output, output2 = graph_to_network(passage_states, query_states,
+                                           self.passage_lengths, self.query_lengths,
+                                           self.graph, self.cfg.dropout,
+                                           is_training, num_heads=cfg.num_heads,
+                                           rnn_units=cfg.rnn_units)
+
+        passage_att_mask = self.passage_mask
+        batch_size_x = tf.shape(self.query_lengths)
+        answer_h = tf.zeros(
+            tf.concat([batch_size_x, tf.constant([cfg.ptr_dim], dtype=tf.int32)], axis=0))
+
+        answer_context = tf.reduce_mean(output2, axis=1)
+
+        query_init_w = tf.get_variable(
+            'query_init_w', shape=[output2.get_shape().as_list()[-1], cfg.ptr_dim])
+        self.query_init = query_init_w
+        answer_context = tf.matmul(answer_context, query_init_w)
+
+        output = tf.transpose(output, perm=[1, 0, 2])
+
+        with tf.variable_scope('answer_ptr_layer'):
+            ptr_att = DotAttention('ptr',
+                                   hidden_dim=cfg.ptr_dim,
+                                   is_vanilla=self.cfg.att_is_vanilla,
+                                   is_identity_transform=self.cfg.att_is_id,
+                                   need_padding=self.cfg.att_need_padding)
+            answer_pre_compute = ptr_att.get_pre_compute(output)
+            ptr_gru = XGRUCell(hidden_dim=cfg.ptr_dim)
+            begin_prob, begin_logits = ptr_att.get_prob(output, answer_context, passage_att_mask,
+                                                        answer_pre_compute, True)
+            att_state = ptr_att.get_att(output, begin_prob)
+            (_, answer_h) = ptr_gru.call(inputs=att_state, state=answer_h)
+            answer_context = answer_h
+            end_prob, end_logits = ptr_att.get_prob(output, answer_context,
+                                                    passage_att_mask, answer_pre_compute,
+                                                    True)
+
+        self.begin_prob = tf.transpose(begin_prob, perm=[1, 0])
+        self.end_prob = tf.transpose(end_prob, perm=[1, 0])
+        begin_logits = tf.transpose(begin_logits, perm=[1, 0])
+        end_logits = tf.transpose(end_logits, perm=[1, 0])
+
+        if is_training:
+            def label_smoothing(inputs, masks, epsilon=0.1):
+                """Modify target for label smoothing."""
+                epsilon = cfg.labelsmoothing
+                num_of_channel = tf.shape(inputs)[-1]  # number of channels
+                inputs = tf.cast(inputs, tf.float32)
+                return (((1 - epsilon) * inputs) + (epsilon /
+                                                    tf.cast(num_of_channel, tf.float32))) * masks
+            cost1 = tf.reduce_mean(
+                tf.losses.softmax_cross_entropy(label_smoothing(
+                    tf.one_hot(self.answer_begin,
+                               depth=tf.shape(self.passage_word)[0]),
+                    tf.transpose(self.passage_mask, perm=[1, 0])), begin_logits))
+            cost2 = tf.reduce_mean(
+                tf.losses.softmax_cross_entropy(
+                    label_smoothing(tf.one_hot(self.answer_end,
+                                               depth=tf.shape(self.passage_word)[0]),
+                                    tf.transpose(self.passage_mask, perm=[1, 0])), end_logits))
+
+            reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+            l2_loss = tf.reduce_sum(reg_ws)
+            loss = cost1 + cost2 + l2_loss
+            self.loss = loss
+
+            optimizer = tf.train.AdamOptimizer(learning_rate=cfg.learning_rate)
+            self.train_op = optimizer.minimize(self.loss)
+
+        return tf.stack([self.begin_prob, self.end_prob])
+
build_char_states(self, char_embed, is_training, reuse, char_ids, char_lengths):
+        """Build char embedding network for the QA model."""
+        max_char_length = self.cfg.max_char_length
+
+        inputs = dropout(tf.nn.embedding_lookup(char_embed, char_ids),
+                         self.cfg.dropout, is_training)
+        inputs = tf.reshape(
+            inputs, shape=[max_char_length, -1, self.cfg.char_embed_dim])
+        char_lengths = tf.reshape(char_lengths, shape=[-1])
+        with tf.variable_scope('char_encoding', reuse=reuse):
+            cell_fw = XGRUCell(hidden_dim=self.cfg.char_embed_dim)
+            cell_bw = XGRUCell(hidden_dim=self.cfg.char_embed_dim)
+            _, (left_right, right_left) = tf.nn.bidirectional_dynamic_rnn(
+                cell_fw=cell_fw,
+                cell_bw=cell_bw,
+                sequence_length=char_lengths,
+                inputs=inputs,
+                time_major=True,
+                dtype=tf.float32
+            )
+
+        left_right = tf.reshape(left_right, shape=[-1, self.cfg.char_embed_dim])
+
+        right_left = tf.reshape(right_left, shape=[-1, self.cfg.char_embed_dim])
+
+        states = tf.concat([left_right, right_left], axis=1)
+        out_shape = tf.shape(char_ids)[1:3]
+        out_shape = tf.concat([out_shape, tf.constant(
+            value=[self.cfg.char_embed_dim * 2], dtype=tf.int32)], axis=0)
+        return tf.reshape(states, shape=out_shape)
diff --git a/examples/trials/weight_sharing/ga_squad/trial.py b/examples/trials/weight_sharing/ga_squad/trial.py
new file mode 100644
index 0000000000..bafe1e707a
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/trial.py
@@ -0,0 +1,461 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import argparse
+import heapq
+import json
+import logging
+import os
+import pickle
+
+import numpy as np
+from tensorflow.train import init_from_checkpoint
+
+import graph
+
+from util import Timer
+
+import nni
+import data
+import evaluate
+from train_model import *
+
+logger = logging.getLogger('ga_squad')
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+
+def get_config():
+    '''
+    Get config from argument parser.
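+    Flags cover the data paths (train/dev/embedding files), the model root and
+    save paths, and the training hyper-parameters that are forwarded to
+    GAGConfig (batch size, learning rate, dropout, label smoothing,
+    attention heads, RNN units).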
+    '''
+    parser = argparse.ArgumentParser(
+        description='This program uses a genetic algorithm to search architectures for SQuAD.')
+    parser.add_argument('--input_file', type=str,
+                        default='./train-v1.1.json', help='input file')
+    parser.add_argument('--dev_file', type=str,
+                        default='./dev-v1.1.json', help='dev file')
+    parser.add_argument('--embedding_file', type=str,
+                        default='./glove.840B.300d.txt', help='embedding file')
+    parser.add_argument('--root_path', default='./data/',
+                        type=str, help='Root path of models')
+    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+    parser.add_argument('--save_path', type=str,
+                        default='./save', help='save path dir')
+    parser.add_argument('--learning_rate', type=float, default=0.0001,
+                        help='set half of the original learning rate to reload data and train.')
+    parser.add_argument('--max_epoch', type=int, default=30)
+    parser.add_argument('--dropout_rate', type=float,
+                        default=0.1, help='dropout_rate')
+    parser.add_argument('--labelsmoothing', type=float,
+                        default=0.1, help='labelsmoothing')
+    parser.add_argument('--num_heads', type=int, default=1, help='num_heads')
+    parser.add_argument('--rnn_units', type=int, default=256, help='rnn_units')
+
+    args = parser.parse_args()
+    return args
+
+
+def get_id(word_dict, word):
+    '''
+    Return word id.
+    '''
+    if word in word_dict:
+        return word_dict[word]
+    return word_dict['']
+
+
+def load_embedding(path):
+    '''
+    Return the embedding for a specific file given its path.
+    '''
+    EMBEDDING_DIM = 300
+    embedding_dict = {}
+    with open(path, 'r', encoding='utf-8') as file:
+        pairs = [line.strip('\r\n').split() for line in file.readlines()]
+        for pair in pairs:
+            if len(pair) == EMBEDDING_DIM + 1:
+                embedding_dict[pair[0]] = [float(x) for x in pair[1:]]
+    logger.debug('embedding_dict size: %d', len(embedding_dict))
+    return embedding_dict
+
+
+class MaxQueue:
+    '''
+    Queue for max value.
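+    Keeps at most `capacity` items, retaining the largest values seen so far:
+    `push` uses a min-heap (`heapq`), so once the queue is full,
+    `heappushpop` evicts the smallest entry.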
+ ''' + + def __init__(self, capacity): + assert capacity > 0, 'queue size must be larger than 0' + self._capacity = capacity + self._entries = [] + + @property + def entries(self): + return self._entries + + @property + def capacity(self): + return self._capacity + + @property + def size(self): + return len(self._entries) + + def clear(self): + self._entries = [] + + def push(self, item): + if self.size < self.capacity: + heapq.heappush(self.entries, item) + else: + heapq.heappushpop(self.entries, item) + + +def find_best_answer_span(left_prob, right_prob, passage_length, max_answer_length): + left = 0 + right = 0 + max_prob = left_prob[0] * right_prob[0] + for i in range(0, passage_length): + left_p = left_prob[i] + for j in range(i, min(i + max_answer_length, passage_length)): + total_prob = left_p * right_prob[j] + if max_prob < total_prob: + left, right, max_prob = i, j, total_prob + return [(max_prob, left, right)] + + +def write_prediction(path, position1_result, position2_result): + import codecs + + with codecs.open(path, 'w', encoding='utf8') as file: + batch_num = len(position1_result) + for i in range(batch_num): + position1_batch = position1_result[i] + position2_batch = position2_result[i] + + for j in range(position1_batch.shape[0]): + file.write(str(position1_batch[j]) + + '\t' + str(position2_batch[j]) + '\n') + + +def find_kbest_answer_span(k, left_prob, right_prob, passage_length, max_answer_length): + if k == 1: + return find_best_answer_span(left_prob, right_prob, passage_length, max_answer_length) + + queue = MaxQueue(k) + for i in range(0, passage_length): + left_p = left_prob[i] + for j in range(i, min(i + max_answer_length, passage_length)): + total_prob = left_p * right_prob[j] + queue.push((total_prob, i, j)) + return list(sorted(queue.entries, key=lambda x: -x[0])) + + +def run_epoch(batches, answer_net, is_training): + if not is_training: + position1_result = [] + position2_result = [] + contexts = [] + ids = [] + + loss_sum = 0 + timer = Timer() + count = 0 + for batch in batches: + used = timer.get_elapsed(False) + count += 1 + qps = batch['qp_pairs'] + question_tokens = [qp['question_tokens'] for qp in qps] + passage_tokens = [qp['passage_tokens'] for qp in qps] + context = [(qp['passage'], qp['passage_tokens']) for qp in qps] + sample_id = [qp['id'] for qp in qps] + + _, query, query_mask, query_lengths = data.get_word_input( + data=question_tokens, word_dict=word_vcb, embed=embed, embed_dim=cfg.word_embed_dim) + _, passage, passage_mask, passage_lengths = data.get_word_input( + data=passage_tokens, word_dict=word_vcb, embed=embed, embed_dim=cfg.word_embed_dim) + + query_char, query_char_lengths = data.get_char_input( + data=question_tokens, char_dict=char_vcb, max_char_length=cfg.max_char_length) + + passage_char, passage_char_lengths = data.get_char_input( + data=passage_tokens, char_dict=char_vcb, max_char_length=cfg.max_char_length) + + if is_training: + answer_begin, answer_end = data.get_answer_begin_end(qps) + + if is_training: + feed_dict = {answer_net.query_word: query, + answer_net.query_mask: query_mask, + answer_net.query_lengths: query_lengths, + answer_net.passage_word: passage, + answer_net.passage_mask: passage_mask, + answer_net.passage_lengths: passage_lengths, + answer_net.query_char_ids: query_char, + answer_net.query_char_lengths: query_char_lengths, + answer_net.passage_char_ids: passage_char, + answer_net.passage_char_lengths: passage_char_lengths, + answer_net.answer_begin: answer_begin, + answer_net.answer_end: answer_end} + loss, 
_, = sess.run(
+                [answer_net.loss, answer_net.train_op], feed_dict=feed_dict)
+            if count % 100 == 0:
+                logger.debug('%d %g expected:%g, loss:%g' %
+                             (count, used, used / count * len(batches), loss))
+            loss_sum += loss
+        else:
+            feed_dict = {answer_net.query_word: query,
+                         answer_net.query_mask: query_mask,
+                         answer_net.query_lengths: query_lengths,
+                         answer_net.passage_word: passage,
+                         answer_net.passage_mask: passage_mask,
+                         answer_net.passage_lengths: passage_lengths,
+                         answer_net.query_char_ids: query_char,
+                         answer_net.query_char_lengths: query_char_lengths,
+                         answer_net.passage_char_ids: passage_char,
+                         answer_net.passage_char_lengths: passage_char_lengths}
+            position1, position2 = sess.run(
+                [answer_net.begin_prob, answer_net.end_prob], feed_dict=feed_dict)
+            position1_result += position1.tolist()
+            position2_result += position2.tolist()
+            contexts += context
+            ids = np.concatenate((ids, sample_id))
+            if count % 100 == 0:
+                logger.debug('%d %g expected:%g' %
+                             (count, used, used / count * len(batches)))
+    loss = loss_sum / len(batches)
+    if is_training:
+        return loss
+    return loss, position1_result, position2_result, ids, contexts
+
+
+def generate_predict_json(position1_result, position2_result, ids, passage_tokens):
+    '''
+    Generate answer JSON from predictions.
+    '''
+    predict_len = len(position1_result)
+    logger.debug('total prediction num is %s', str(predict_len))
+
+    answers = {}
+    for i in range(predict_len):
+        sample_id = ids[i]
+        passage, tokens = passage_tokens[i]
+        kbest = find_best_answer_span(
+            position1_result[i], position2_result[i], len(tokens), 23)
+        _, start, end = kbest[0]
+        answer = passage[tokens[start]['char_begin']:tokens[end]['char_end']]
+        answers[sample_id] = answer
+    logger.debug('generate predict done.')
+    return answers
+
+
+def generate_data(path, tokenizer, char_vcb, word_vcb, is_training=False):
+    '''
+    Generate data.
+    '''
+    global root_path
+    qp_pairs = data.load_from_file(path=path, is_training=is_training)
+
+    tokenized_sent = 0
+    # qp_pairs = qp_pairs[:1000]
+    for qp_pair in qp_pairs:
+        tokenized_sent += 1
+        data.tokenize(qp_pair, tokenizer, is_training)
+        for word in qp_pair['question_tokens']:
+            word_vcb.add(word['word'])
+            for char in word['word']:
+                char_vcb.add(char)
+        for word in qp_pair['passage_tokens']:
+            word_vcb.add(word['word'])
+            for char in word['word']:
+                char_vcb.add(char)
+
+    max_query_length = max(len(x['question_tokens']) for x in qp_pairs)
+    max_passage_length = max(len(x['passage_tokens']) for x in qp_pairs)
+    #min_passage_length = min(len(x['passage_tokens']) for x in qp_pairs)
+    cfg.max_query_length = max_query_length
+    cfg.max_passage_length = max_passage_length
+
+    return qp_pairs
+
+
+def train_with_graph(p_graph, qp_pairs, dev_qp_pairs):
+    '''
+    Train a network from a specific graph.
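+    Builds a training model and a dev model that share variables. When
+    `restore_path` is set, the scopes listed in `restore_shared` are first
+    initialized from the parent trial's checkpoint (weight sharing over NFS);
+    afterwards the remaining variables are initialized and training runs with
+    patience-based early stopping, reporting dev accuracy to NNI every epoch.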
+    '''
+    global sess
+    with tf.Graph().as_default():
+        train_model = GAG(cfg, embed, p_graph)
+        train_model.build_net(is_training=True)
+        tf.get_variable_scope().reuse_variables()
+        dev_model = GAG(cfg, embed, p_graph)
+        dev_model.build_net(is_training=False)
+        with tf.Session() as sess:
+            if restore_path is not None:
+                restore_mapping = dict(zip(restore_shared, restore_shared))
+                logger.debug('init shared variables from {}, restore_scopes: {}'.format(restore_path, restore_shared))
+                init_from_checkpoint(restore_path, restore_mapping)
+            logger.debug('init variables')
+            logger.debug(sess.run(tf.report_uninitialized_variables()))
+            init = tf.global_variables_initializer()
+            sess.run(init)
+            # writer = tf.summary.FileWriter('%s/graph/'%execution_path, sess.graph)
+            logger.debug('assign to graph')
+
+            saver = tf.train.Saver()
+            train_loss = None
+            bestacc = 0
+            patience = 5
+            patience_increase = 2
+            improvement_threshold = 0.995
+
+            for epoch in range(max_epoch):
+                logger.debug('begin to train')
+                train_batches = data.get_batches(qp_pairs, cfg.batch_size)
+                train_loss = run_epoch(train_batches, train_model, True)
+                logger.debug('epoch ' + str(epoch) +
+                             ' loss: ' + str(train_loss))
+                dev_batches = list(data.get_batches(
+                    dev_qp_pairs, cfg.batch_size))
+                _, position1, position2, ids, contexts = run_epoch(
+                    dev_batches, dev_model, False)
+
+                answers = generate_predict_json(
+                    position1, position2, ids, contexts)
+                if save_path is not None:
+                    logger.info('save prediction file to {}'.format(save_path))
+                    with open(os.path.join(save_path, 'epoch%d.prediction' % epoch), 'w') as file:
+                        json.dump(answers, file)
+                else:
+                    answers = json.dumps(answers)
+                    answers = json.loads(answers)
+                cur_iter = epoch + 1
+
+                acc = evaluate.evaluate_with_predictions(
+                    args.dev_file, answers)
+
+                logger.debug('Send intermediate acc: %s', str(acc))
+                nni.report_intermediate_result(acc)
+
+                logger.debug('Send intermediate result done.')
+
+                if acc > bestacc:
+                    if acc * improvement_threshold > bestacc:
+                        patience = max(patience, cur_iter * patience_increase)
+                    bestacc = acc
+
+                    if save_path is not None:
+                        logger.info('save model & prediction to {}'.format(save_path))
+                        saver.save(sess, os.path.join(save_path, 'epoch%d.model' % epoch))
+                        with open(os.path.join(save_path, 'epoch%d.score' % epoch), 'wb') as file:
+                            pickle.dump(
+                                (position1, position2, ids, contexts), file)
+                logger.debug('epoch %d acc %g bestacc %g' %
+                             (epoch, acc, bestacc))
+                if patience <= cur_iter:
+                    break
+            logger.debug('save done.')
+            return train_loss, bestacc
+
+
+embed = None
+char_vcb = None
+tokenizer = None
+word_vcb = None
+
+
+def load_data():
+    global embed, char_vcb, tokenizer, word_vcb
+    logger.debug('tokenize data')
+    tokenizer = data.WhitespaceTokenizer()
+
+    char_set = set()
+    word_set = set()
+    logger.debug('generate train data')
+    qp_pairs = generate_data(input_file, tokenizer,
+                             char_set, word_set, is_training=True)
+    logger.debug('generate dev data')
+    dev_qp_pairs = generate_data(
+        dev_file, tokenizer, char_set, word_set, is_training=False)
+    logger.debug('generate data done.')
+
+    char_vcb = {char: sample_id for sample_id, char in enumerate(char_set)}
+    word_vcb = {word: sample_id for sample_id, word in enumerate(word_set)}
+
+    timer.start()
+    logger.debug('read embedding table')
+
+    cfg.word_embed_dim = 300
+    embed = np.zeros((len(word_vcb), cfg.word_embed_dim), dtype=np.float32)
+
+    embedding = load_embedding(args.embedding_file)
+    # iterate over (word, id) pairs; enumerate(word_vcb) would yield
+    # (index, word) and never match the embedding keys
+    for word, sample_id in word_vcb.items():
+        if word in embedding:
+            embed[sample_id] = embedding[word]
+
+    
# add UNK into dict
+    unk = np.zeros((1, cfg.word_embed_dim), dtype=np.float32)
+    embed = np.concatenate((unk, embed), axis=0)
+    word_vcb = {key: value + 1 for key, value in word_vcb.items()}
+
+    return qp_pairs, dev_qp_pairs
+
+
+if __name__ == '__main__':
+    try:
+        args = get_config()
+
+        root_path = os.path.expanduser(args.root_path)
+        input_file = os.path.expanduser(args.input_file)
+        dev_file = os.path.expanduser(args.dev_file)
+        max_epoch = args.max_epoch
+
+        cfg = GAGConfig()
+        cfg.batch_size = args.batch_size
+        cfg.learning_rate = float(args.learning_rate)
+        cfg.dropout = args.dropout_rate
+        cfg.rnn_units = args.rnn_units
+        cfg.labelsmoothing = args.labelsmoothing
+        cfg.num_heads = args.num_heads
+        timer = Timer()
+
+        qp_pairs, dev_qp_pairs = load_data()
+        logger.debug('Init finish.')
+
+        original_params = nni.get_next_parameter()
+        '''
+        with open('data.json') as f:
+            original_params = json.load(f)
+        '''
+        p_graph = graph.graph_loads(original_params['graph'])
+        save_path = original_params['save_dir']
+        os.makedirs(save_path)
+        restore_path = original_params['restore_dir']
+        # the embedding and char-encoding scopes are always restored along with
+        # the shared layers; parentheses keep the concatenation out of the else branch
+        restore_shared = ([hash_id + '/' for hash_id in original_params['shared_id']]
+                          if original_params['shared_id'] is not None else []) \
+            + ['word_embed', 'char_embed', 'char_encoding/']
+        train_loss, best_acc = train_with_graph(p_graph, qp_pairs, dev_qp_pairs)
+
+        logger.debug('Send best acc: %s', str(best_acc))
+        nni.report_final_result(best_acc)
+        logger.debug('Send final result done')
+    except:
+        logger.exception('Caught exception in trial.py.')
+        raise
diff --git a/examples/trials/weight_sharing/ga_squad/util.py b/examples/trials/weight_sharing/ga_squad/util.py
new file mode 100644
index 0000000000..ac9f363003
--- /dev/null
+++ b/examples/trials/weight_sharing/ga_squad/util.py
@@ -0,0 +1,76 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+'''
+Util Module
+'''
+
+import time
+
+import tensorflow as tf
+
+
+def shape(tensor):
+    '''
+    Get shape of variable.
+    Return type is tuple.
+    '''
+    temp_s = tensor.get_shape()
+    return tuple([temp_s[i].value for i in range(0, len(temp_s))])
+
+
+def get_variable(name, temp_s):
+    '''
+    Create a zero-initialized variable with the given name and shape.
+    '''
+    return tf.Variable(tf.zeros(temp_s), name=name)
+
+
+def dropout(tensor, drop_prob, is_training):
+    '''
+    Apply dropout, except at test time.
+    '''
+    if not is_training:
+        return tensor
+    return tf.nn.dropout(tensor, 1.0 - drop_prob)
+
+
+class Timer:
+    '''
+    Timer for measuring elapsed time.
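+    `get_elapsed()` returns the seconds since the last start (or the last
+    elapsed call) and restarts the timer unless called with `restart=False`.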
+    '''
+    def __init__(self):
+        self.__start = time.time()
+
+    def start(self):
+        '''
+        Start (or restart) the timer.
+        '''
+        self.__start = time.time()
+
+    def get_elapsed(self, restart=True):
+        '''
+        Return the elapsed time span.
+        '''
+        end = time.time()
+        span = end - self.__start
+        if restart:
+            self.__start = end
+        return span
diff --git a/examples/tuners/ga_customer_tuner/customer_tuner.py b/examples/tuners/ga_customer_tuner/customer_tuner.py
index 2cfae001e5..699df5eb0e 100644
--- a/examples/tuners/ga_customer_tuner/customer_tuner.py
+++ b/examples/tuners/ga_customer_tuner/customer_tuner.py
@@ -96,7 +96,7 @@ def generate_parameters(self, parameter_id):
                 temp = json.loads(graph_dumps(indiv.config))
             else:
                 random.shuffle(self.population)
-                if self.population[0].result > self.population[1].result:
+                if self.population[0].result < self.population[1].result:
                     self.population[0] = self.population[1]
                 indiv = copy.deepcopy(self.population[0])
                 self.population.pop(1)
diff --git a/examples/tuners/weight_sharing/ga_customer_tuner/README.md b/examples/tuners/weight_sharing/ga_customer_tuner/README.md
new file mode 100644
index 0000000000..bc7a6f1f84
--- /dev/null
+++ b/examples/tuners/weight_sharing/ga_customer_tuner/README.md
@@ -0,0 +1,15 @@
+# How to use ga_customer_tuner?
+This tuner is customized for trials whose code path is `~/nni/examples/trials/ga_squad`;
+type `cd ~/nni/examples/trials/ga_squad` and check `readme.md` for more information about the ga_squad trial.
+
+# Config
+If you want to use ga_customer_tuner in your experiment, you can set the config file in the following format:
+
+```
+tuner:
+  codeDir: ~/nni/examples/tuners/ga_customer_tuner
+  classFileName: customer_tuner.py
+  className: CustomerTuner
+  classArgs:
+    optimize_mode: maximize
+```
\ No newline at end of file
diff --git a/examples/tuners/weight_sharing/ga_customer_tuner/__init__.py b/examples/tuners/weight_sharing/ga_customer_tuner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/tuners/weight_sharing/ga_customer_tuner/customer_tuner.py b/examples/tuners/weight_sharing/ga_customer_tuner/customer_tuner.py
new file mode 100644
index 0000000000..86520b5220
--- /dev/null
+++ b/examples/tuners/weight_sharing/ga_customer_tuner/customer_tuner.py
@@ -0,0 +1,224 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
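+
+'''
+Evolution tuner with weight sharing for the ga_squad example: each individual
+records `save_dir`/`restore_dir` under an NFS root so that a child trial can
+restore the layers listed in `shared_id` (matched by `hash_id`) from its
+parent's checkpoint before training.
+'''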
+ + +import copy +import json +import logging +import random +import os + +from threading import Event, Lock, current_thread + +from nni.tuner import Tuner + +from graph import Graph, Layer, LayerType, Enum, graph_dumps, graph_loads, unique + +logger = logging.getLogger('ga_customer_tuner') + + +@unique +class OptimizeMode(Enum): + Minimize = 'minimize' + Maximize = 'maximize' + + + + +class Individual(object): + """ + Basic Unit for evolution algorithm + """ + def __init__(self, graph_cfg: Graph = None, info=None, result=None, indiv_id=None): + self.config = graph_cfg + self.result = result + self.info = info + self.indiv_id = indiv_id + self.parent_id = None + self.shared_ids = {layer.hash_id for layer in self.config.layers if layer.is_delete is False} + + def __str__(self): + return "info: " + str(self.info) + ", config :" + str(self.config) + ", result: " + str(self.result) + + def mutation(self, indiv_id: int, graph_cfg: Graph = None, info=None): + self.result = None + if graph_cfg is not None: + self.config = graph_cfg + self.config.mutation() + self.info = info + self.parent_id = self.indiv_id + self.indiv_id = indiv_id + self.shared_ids.intersection_update({layer.hash_id for layer in self.config.layers if layer.is_delete is False}) + + +class CustomerTuner(Tuner): + """ + NAS Tuner using Evolution Algorithm, with weight sharing enabled + """ + def __init__(self, optimize_mode, save_dir_root, population_size=32, graph_max_layer=6, graph_min_layer=3): + self.optimize_mode = OptimizeMode(optimize_mode) + self.indiv_counter = 0 + self.events = [] + self.thread_lock = Lock() + self.save_dir_root = save_dir_root + self.population = self.init_population(population_size, graph_max_layer, graph_min_layer) + assert len(self.population) == population_size + logger.debug('init population done.') + return + + def generate_new_id(self): + """ + generate new id and event hook for new Individual + """ + self.events.append(Event()) + indiv_id = self.indiv_counter + self.indiv_counter += 1 + return indiv_id + + def save_dir(self, indiv_id): + if indiv_id is None: + return None + else: + return os.path.join(self.save_dir_root, str(indiv_id)) + + def init_population(self, population_size, graph_max_layer, graph_min_layer): + """ + initialize populations for evolution tuner + """ + population = [] + graph = Graph(max_layer_num=graph_max_layer, min_layer_num=graph_min_layer, + inputs=[Layer(LayerType.input.value, output=[4, 5], size='x'), Layer(LayerType.input.value, output=[4, 5], size='y')], + output=[Layer(LayerType.output.value, inputs=[4], size='x'), Layer(LayerType.output.value, inputs=[5], size='y')], + hide=[Layer(LayerType.attention.value, inputs=[0, 1], output=[2]), + Layer(LayerType.attention.value, inputs=[1, 0], output=[3])]) + for _ in range(population_size): + graph_tmp = copy.deepcopy(graph) + graph_tmp.mutation() + population.append(Individual(indiv_id=self.generate_new_id(), graph_cfg=graph_tmp, result=None)) + return population + + def generate_parameters(self, parameter_id): + """Returns a set of trial graph config, as a serializable object. + An example configuration: + ```json + { + "shared_id": [ + "4a11b2ef9cb7211590dfe81039b27670", + "370af04de24985e5ea5b3d72b12644c9", + "11f646e9f650f5f3fedc12b6349ec60f", + "0604e5350b9c734dd2d770ee877cfb26", + "6dbeb8b022083396acb721267335f228", + "ba55380d6c84f5caeb87155d1c5fa654" + ], + "graph": { + "layers": [ + ... 
+            {
+                "hash_id": "ba55380d6c84f5caeb87155d1c5fa654",
+                "is_delete": false,
+                "size": "x",
+                "graph_type": 0,
+                "output": [
+                    6
+                ],
+                "output_size": 1,
+                "input": [
+                    7,
+                    1
+                ],
+                "input_size": 2
+            },
+            ...
+            ]
+        },
+        "restore_dir": "/mnt/nfs/nni/ga_squad/87",
+        "save_dir": "/mnt/nfs/nni/ga_squad/95"
+        }
+        ```
+        `restore_dir` means the path from which to load the previously trained model weights; if null, init from scratch.
+        `save_dir` means the path to save the trained model for the current trial.
+        `graph` is the configuration of the model network.
+        Note: each layer configuration has a `hash_id` property,
+        which tells the tuner & trial code whether to share trained weights or not.
+        `shared_id` lists the hash_ids of layers that should be shared with the previously trained model.
+        """
+        logger.debug('acquiring lock for param {}'.format(parameter_id))
+        self.thread_lock.acquire()
+        logger.debug('lock for current thread acquired')
+        if not self.population:
+            logger.debug("the population is empty.")
+            raise Exception('The population is empty')
+        pos = -1
+        for i in range(len(self.population)):
+            if self.population[i].result is None:
+                pos = i
+                break
+        if pos != -1:
+            indiv = copy.deepcopy(self.population[pos])
+            self.population.pop(pos)
+            graph_param = json.loads(graph_dumps(indiv.config))
+        else:
+            random.shuffle(self.population)
+            if self.population[0].result < self.population[1].result:
+                self.population[0] = self.population[1]
+            indiv = copy.deepcopy(self.population[0])
+            self.population.pop(1)
+            indiv.mutation(indiv_id=self.generate_new_id())
+            graph_param = json.loads(graph_dumps(indiv.config))
+        param_json = {
+            'graph': graph_param,
+            'restore_dir': self.save_dir(indiv.parent_id),
+            'save_dir': self.save_dir(indiv.indiv_id),
+            'shared_id': list(indiv.shared_ids) if indiv.parent_id is not None else None,
+        }
+        logger.debug('generate_parameter return value is:')
+        logger.debug(param_json)
+        logger.debug('releasing lock')
+        self.thread_lock.release()
+        if indiv.parent_id is not None:
+            logger.debug("new trial {} pending on parent experiment {}".format(indiv.indiv_id, indiv.parent_id))
+            self.events[indiv.parent_id].wait()
+        logger.debug("trial {} ready".format(indiv.indiv_id))
+        return param_json
+
+    def receive_trial_result(self, parameter_id, parameters, value):
+        '''
+        Record an observation of the objective function.
+        parameter_id : int
+        parameters : dict of parameters
+        value : final metrics of the trial, including reward
+        '''
+        logger.debug('acquiring lock for param {}'.format(parameter_id))
+        self.thread_lock.acquire()
+        logger.debug('lock for current thread acquired')
+        reward = self.extract_scalar_reward(value)
+        if self.optimize_mode is OptimizeMode.Minimize:
+            reward = -reward
+
+        logger.debug('receive trial result is:\n')
+        logger.debug(str(parameters))
+        logger.debug(str(reward))
+
+        indiv = Individual(indiv_id=int(os.path.split(parameters['save_dir'])[1]),
+                           graph_cfg=graph_loads(parameters['graph']), result=reward)
+        self.population.append(indiv)
+        logger.debug('releasing lock')
+        self.thread_lock.release()
+        self.events[indiv.indiv_id].set()
+
+    def update_search_space(self, data):
+        pass
diff --git a/examples/tuners/weight_sharing/ga_customer_tuner/graph.py b/examples/tuners/weight_sharing/ga_customer_tuner/graph.py
new file mode 100644
index 0000000000..8e675a06ff
--- /dev/null
+++ b/examples/tuners/weight_sharing/ga_customer_tuner/graph.py
@@ -0,0 +1,336 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+Graph is a customized class; this module contains the classes and functions related to graphs.
+'''
+
+
+import copy
+import hashlib
+import logging
+import json
+import random
+from collections import deque
+from enum import Enum, unique
+from typing import Iterable
+
+import numpy as np
+
+_logger = logging.getLogger('ga_squad_graph')
+
+@unique
+class LayerType(Enum):
+    '''
+    Layer type
+    '''
+    attention = 0
+    self_attention = 1
+    rnn = 2
+    input = 3
+    output = 4
+
+class Layer(object):
+    '''
+    Layer class, which contains the information of a graph node.
+    '''
+    def __init__(self, graph_type, inputs=None, output=None, size=None, hash_id=None):
+        self.input = inputs if inputs is not None else []
+        self.output = output if output is not None else []
+        self.graph_type = graph_type
+        self.is_delete = False
+        self.size = size
+        self.hash_id = hash_id
+        if graph_type == LayerType.attention.value:
+            self.input_size = 2
+            self.output_size = 1
+        elif graph_type == LayerType.rnn.value:
+            self.input_size = 1
+            self.output_size = 1
+        elif graph_type == LayerType.self_attention.value:
+            self.input_size = 1
+            self.output_size = 1
+        elif graph_type == LayerType.input.value:
+            self.input_size = 0
+            self.output_size = 1
+            if self.hash_id is None:
+                hasher = hashlib.md5()
+                hasher.update(np.random.bytes(100))
+                self.hash_id = hasher.hexdigest()
+        elif graph_type == LayerType.output.value:
+            self.input_size = 1
+            self.output_size = 0
+        else:
+            raise ValueError('Unsupported LayerType: {}'.format(graph_type))
+
+    def update_hash(self, layers: Iterable):
+        """
+        Calculate the `hash_id` of this layer, which is determined by its own
+        properties and the `hash_id`s of its input layers.
+        """
+        if self.graph_type == LayerType.input.value:
+            return
+        hasher = hashlib.md5()
+        hasher.update(LayerType(self.graph_type).name.encode('ascii'))
+        hasher.update(str(self.size).encode('ascii'))
+        for i in self.input:
+            if layers[i].hash_id is None:
+                raise ValueError('Hash id of layer {}: {} not generated!'.format(i, layers[i]))
+            hasher.update(layers[i].hash_id.encode('ascii'))
+        self.hash_id = hasher.hexdigest()
+
+    def set_size(self, graph_id, size):
+        '''
+        Set size.
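+        Propagate `size` from input layer `graph_id`; output layers instead
+        check the incoming size and return False on a mismatch, which is how
+        `is_topology` rejects inconsistent graphs.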
+        '''
+        if self.graph_type == LayerType.attention.value:
+            if self.input[0] == graph_id:
+                self.size = size
+        if self.graph_type == LayerType.rnn.value:
+            self.size = size
+        if self.graph_type == LayerType.self_attention.value:
+            self.size = size
+        if self.graph_type == LayerType.output.value:
+            if self.size != size:
+                return False
+        return True
+
+    def clear_size(self):
+        '''
+        Clear size.
+        '''
+        # membership test instead of chained `or`, which was always true
+        if self.graph_type in (LayerType.attention.value,
+                               LayerType.rnn.value, LayerType.self_attention.value):
+            self.size = None
+
+    def __str__(self):
+        return 'input:' + str(self.input) + ' output:' + str(self.output) + ' type:' + str(self.graph_type) + ' is_delete:' + str(self.is_delete) + ' size:' + str(self.size)
+
+def graph_dumps(graph):
+    '''
+    Dump the graph to JSON.
+    '''
+    return json.dumps(graph, default=lambda obj: obj.__dict__)
+
+def graph_loads(graph_json):
+    '''
+    Load a graph from JSON.
+    '''
+    layers = []
+    for layer in graph_json['layers']:
+        layer_info = Layer(layer['graph_type'], layer['input'], layer['output'], layer['size'], layer['hash_id'])
+        layer_info.is_delete = layer['is_delete']
+        _logger.debug('append layer {}'.format(layer_info))
+        layers.append(layer_info)
+    graph = Graph(graph_json['max_layer_num'], graph_json['min_layer_num'], [], [], [])
+    graph.layers = layers
+    _logger.debug('graph {} loaded'.format(graph))
+    return graph
+
+class Graph(object):
+    '''
+    Customized Graph class.
+    '''
+    def __init__(self, max_layer_num, min_layer_num, inputs, output, hide):
+        self.layers = []
+        self.max_layer_num = max_layer_num
+        self.min_layer_num = min_layer_num
+        assert min_layer_num < max_layer_num
+
+        for layer in inputs:
+            self.layers.append(layer)
+        for layer in output:
+            self.layers.append(layer)
+        if hide is not None:
+            for layer in hide:
+                self.layers.append(layer)
+        assert self.is_legal()
+
+    def is_topology(self, layers=None):
+        '''
+        Validate the topology.
+        '''
+        if layers is None:
+            layers = self.layers
+        layers_nodle = []
+        result = []
+        for i, layer in enumerate(layers):
+            if layer.is_delete is False:
+                layers_nodle.append(i)
+        while True:
+            flag_break = True
+            layers_toremove = []
+            for layer1 in layers_nodle:
+                flag_arrive = True
+                for layer2 in layers[layer1].input:
+                    if layer2 in layers_nodle:
+                        flag_arrive = False
+                if flag_arrive is True:
+                    for layer2 in layers[layer1].output:
+                        # size mismatch
+                        if layers[layer2].set_size(layer1, layers[layer1].size) is False:
+                            return False
+                    layers_toremove.append(layer1)
+                    result.append(layer1)
+                    flag_break = False
+            for layer in layers_toremove:
+                layers_nodle.remove(layer)
+            result.append('|')
+            if flag_break:
+                break
+        # there is a loop in the graph, or some layers can't be reached
+        if layers_nodle:
+            return False
+        return result
+
+    def layer_num(self, layers=None):
+        '''
+        Return the number of layers.
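+        Only non-deleted hidden layers are counted; input and
+        output layers are excluded.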
+        '''
+        if layers is None:
+            layers = self.layers
+        layer_num = 0
+        for layer in layers:
+            if layer.is_delete is False and layer.graph_type != LayerType.input.value\
+                and layer.graph_type != LayerType.output.value:
+                layer_num += 1
+        return layer_num
+
+    def is_legal(self, layers=None):
+        '''
+        Judge whether the layers are legal.
+        '''
+        if layers is None:
+            layers = self.layers
+
+        for layer in layers:
+            if layer.is_delete is False:
+                if len(layer.input) != layer.input_size:
+                    return False
+                if len(layer.output) < layer.output_size:
+                    return False
+
+        # layer_num <= max_layer_num
+        if self.layer_num(layers) > self.max_layer_num:
+            return False
+
+        # there is a loop in the graph, or some layers can't be reached
+        if self.is_topology(layers) is False:
+            return False
+
+        return True
+
+    def update_hash(self):
+        """
+        Update the hash id of each layer in topological order;
+        the hash ids are used for weight sharing.
+        """
+        _logger.debug('update hash')
+        layer_in_cnt = [len(layer.input) for layer in self.layers]
+        topo_queue = deque([i for i, layer in enumerate(self.layers) if not layer.is_delete and layer.graph_type == LayerType.input.value])
+        while topo_queue:
+            layer_i = topo_queue.pop()
+            self.layers[layer_i].update_hash(self.layers)
+            for layer_j in self.layers[layer_i].output:
+                layer_in_cnt[layer_j] -= 1
+                if layer_in_cnt[layer_j] == 0:
+                    topo_queue.appendleft(layer_j)
+
+    def mutation(self, only_add=False):
+        '''
+        Mutate the graph.
+        '''
+        types = []
+        if self.layer_num() < self.max_layer_num:
+            types.append(0)
+            types.append(1)
+        if self.layer_num() > self.min_layer_num and only_add is False:
+            types.append(2)
+            types.append(3)
+        # 0: add a layer, delete an edge
+        # 1: add a layer, change an edge
+        # 2: delete a layer, delete an edge
+        # 3: delete a layer, change an edge
+        graph_type = random.choice(types)
+        layer_type = random.choice([LayerType.attention.value,\
+            LayerType.self_attention.value, LayerType.rnn.value])
+        layers = copy.deepcopy(self.layers)
+        cnt_try = 0
+        while True:
+            layers_in = []
+            layers_out = []
+            layers_del = []
+            for i, layer in enumerate(layers):
+                if layer.is_delete is False:
+                    if layer.graph_type != LayerType.output.value:
+                        layers_in.append(i)
+                    if layer.graph_type != LayerType.input.value:
+                        layers_out.append(i)
+                    if layer.graph_type != LayerType.output.value\
+                        and layer.graph_type != LayerType.input.value:
+                        layers_del.append(i)
+            if graph_type <= 1:
+                new_id = len(layers)
+                out = random.choice(layers_out)
+                inputs = []
+                output = [out]
+                pos = random.randint(0, len(layers[out].input) - 1)
+                last_in = layers[out].input[pos]
+                layers[out].input[pos] = new_id
+                if graph_type == 0:
+                    layers[last_in].output.remove(out)
+                if graph_type == 1:
+                    layers[last_in].output.remove(out)
+                    layers[last_in].output.append(new_id)
+                    inputs = [last_in]
+                lay = Layer(graph_type=layer_type, inputs=inputs, output=output)
+                while len(inputs) < lay.input_size:
+                    layer1 = random.choice(layers_in)
+                    inputs.append(layer1)
+                    layers[layer1].output.append(new_id)
+                lay.input = inputs
+                layers.append(lay)
+            else:
+                layer1 = random.choice(layers_del)
+                for layer2 in layers[layer1].output:
+                    layers[layer2].input.remove(layer1)
+                    if graph_type == 2:
+                        random_in = random.choice(layers_in)
+                    else:
+                        random_in = random.choice(layers[layer1].input)
+                    layers[layer2].input.append(random_in)
+                    layers[random_in].output.append(layer2)
+                for layer2 in layers[layer1].input:
+                    layers[layer2].output.remove(layer1)
+                layers[layer1].is_delete = True
+
+            if self.is_legal(layers):
+                self.layers = 
layers + break + else: + layers = copy.deepcopy(self.layers) + cnt_try += 1 + self.update_hash() + + def __str__(self): + info = "" + for l_id, layer in enumerate(self.layers): + if layer.is_delete is False: + info += 'id:%d ' % l_id + str(layer) + '\n' + return info diff --git a/src/sdk/pynni/nni/common.py b/src/sdk/pynni/nni/common.py index cb21efda64..03fd870c31 100644 --- a/src/sdk/pynni/nni/common.py +++ b/src/sdk/pynni/nni/common.py @@ -63,8 +63,7 @@ def init_logger(logger_file_path): elif env_args.log_dir is not None: logger_file_path = os.path.join(env_args.log_dir, logger_file_path) logger_file = open(logger_file_path, 'w') - - fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s' + fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' formatter = logging.Formatter(fmt, _time_format) handler = logging.StreamHandler(logger_file) diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py index 4275e58e7e..325befc7d1 100644 --- a/src/sdk/pynni/nni/msg_dispatcher.py +++ b/src/sdk/pynni/nni/msg_dispatcher.py @@ -97,6 +97,7 @@ def handle_initialize(self, data): def handle_request_trial_jobs(self, data): # data: number or trial jobs ids = [_create_parameter_id() for _ in range(data)] + _logger.debug("requesting for generating params of {}".format(ids)) params_list = self.tuner.generate_multiple_parameters(ids) for i, _ in enumerate(params_list): diff --git a/src/sdk/pynni/nni/msg_dispatcher_base.py b/src/sdk/pynni/nni/msg_dispatcher_base.py index bcb8cc1a3a..d0b8c8beb0 100644 --- a/src/sdk/pynni/nni/msg_dispatcher_base.py +++ b/src/sdk/pynni/nni/msg_dispatcher_base.py @@ -19,10 +19,14 @@ # ================================================================================================== #import json_tricks -import os import logging -import json_tricks +import os +from queue import Queue +import sys + from multiprocessing.dummy import Pool as ThreadPool + +import json_tricks from .common import init_logger, multi_thread_enabled from .recoverable import Recoverable from .protocol import CommandType, receive @@ -49,7 +53,7 @@ def run(self): if command is None: break if multi_thread_enabled(): - self.pool.map_async(self.handle_request, [(command, data)]) + self.pool.map_async(self.handle_request_thread, [(command, data)]) else: self.handle_request((command, data)) @@ -59,6 +63,16 @@ def run(self): _logger.info('Terminated by NNI manager') + def handle_request_thread(self, request): + if multi_thread_enabled(): + try: + self.handle_request(request) + except Exception as e: + _logger.exception(str(e)) + sys.exit(-1) + else: + pass + def handle_request(self, request): command, data = request diff --git a/src/sdk/pynni/nni/tuner.py b/src/sdk/pynni/nni/tuner.py index 7d65395425..4dcf705bcf 100644 --- a/src/sdk/pynni/nni/tuner.py +++ b/src/sdk/pynni/nni/tuner.py @@ -48,6 +48,7 @@ def generate_multiple_parameters(self, parameter_id_list): result = [] for parameter_id in parameter_id_list: try: + _logger.debug("generating param for {}".format(parameter_id)) res = self.generate_parameters(parameter_id) except nni.NoMoreTrialError: return result diff --git a/test/async_sharing_test/config.yml b/test/async_sharing_test/config.yml new file mode 100644 index 0000000000..8cefad3c1a --- /dev/null +++ b/test/async_sharing_test/config.yml @@ -0,0 +1,25 @@ +authorName: default +experimentName: example_weight_sharing +trialConcurrency: 3 +maxExecDuration: 1h +maxTrialNum: 10 +#choice: local, remote, pai +trainingServicePlatform: remote +#choice: true, false 
+useAnnotation: false +multiThread: true +tuner: + codeDir: . + classFileName: simple_tuner.py + className: SimpleTuner +trial: + command: python3 main.py + codeDir: . + gpuNum: 0 +machineList: + - ip: 10.10.10.10 + username: bob + passwd: bob123 + - ip: 10.10.10.11 + username: bob + passwd: bob123 diff --git a/test/async_sharing_test/main.py b/test/async_sharing_test/main.py new file mode 100644 index 0000000000..4c32ea51ca --- /dev/null +++ b/test/async_sharing_test/main.py @@ -0,0 +1,56 @@ +""" +Test code for weight sharing +need NFS setup and mounted as `/mnt/nfs/nni` +""" + +import hashlib +import os +import random +import time + +import nni + + +def generate_rand_file(fl_name): + """ + generate random file and write to `fl_name` + """ + fl_size = random.randint(1024, 102400) + fl_dir = os.path.split(fl_name)[0] + if not os.path.exists(fl_dir): + os.makedirs(fl_dir) + with open(fl_name, 'wb') as fout: + fout.write(os.urandom(fl_size)) + + +def check_sum(fl_name, tid=None): + """ + compute checksum for generated file of `fl_name` + """ + hasher = hashlib.md5() + with open(fl_name, 'rb') as fin: + for chunk in iter(lambda: fin.read(4096), b""): + hasher.update(chunk) + ret = hasher.hexdigest() + if tid is not None: + ret = ret + str(tid) + return ret + + +if __name__ == '__main__': + nfs_path = '/mnt/nfs/nni' + params = nni.get_next_parameter() + print(params) + if params['prev_id'] == 0: + model_file = os.path.join(nfs_path, str(params['id']), 'model.dat') + time.sleep(10) + generate_rand_file(model_file) + nni.report_final_result({ + 'checksum': check_sum(model_file), + 'path': model_file + }) + else: + model_file = params['prev_path'] + nni.report_final_result({ + 'checksum': check_sum(model_file, params['prev_id']) + }) diff --git a/test/async_sharing_test/simple_tuner.py b/test/async_sharing_test/simple_tuner.py new file mode 100644 index 0000000000..57c39cbe3b --- /dev/null +++ b/test/async_sharing_test/simple_tuner.py @@ -0,0 +1,65 @@ +""" +SimpleTuner for Weight Sharing +""" + +import logging + +from threading import Event, Lock +from nni.tuner import Tuner + +_logger = logging.getLogger('WeightSharingTuner') + + +class SimpleTuner(Tuner): + """ + simple tuner, test for weight sharing + """ + + def __init__(self): + super(SimpleTuner, self).__init__() + self.trial_meta = {} + self.f_id = None # father + self.sig_event = Event() + self.thread_lock = Lock() + + def generate_parameters(self, parameter_id): + if self.f_id is None: + self.thread_lock.acquire() + self.f_id = parameter_id + self.trial_meta[parameter_id] = { + 'prev_id': 0, + 'id': parameter_id, + 'checksum': None, + 'path': '', + } + _logger.info('generate parameter for father trial %s' % + parameter_id) + self.thread_lock.release() + return { + 'prev_id': 0, + 'id': parameter_id, + } + else: + self.sig_event.wait() + self.thread_lock.acquire() + self.trial_meta[parameter_id] = { + 'id': parameter_id, + 'prev_id': self.f_id, + 'prev_path': self.trial_meta[self.f_id]['path'] + } + self.thread_lock.release() + return self.trial_meta[parameter_id] + + def receive_trial_result(self, parameter_id, parameters, reward): + self.thread_lock.acquire() + if parameter_id == self.f_id: + self.trial_meta[parameter_id]['checksum'] = reward['checksum'] + self.trial_meta[parameter_id]['path'] = reward['path'] + self.sig_event.set() + else: + if reward['checksum'] != self.trial_meta[self.f_id]['checksum'] + str(self.f_id): + raise ValueError("Inconsistency in weight sharing!!!") + self.thread_lock.release() + + def 
update_search_space(self, search_space): + pass diff --git a/tools/nni_cmd/launcher.py b/tools/nni_cmd/launcher.py index 7fa5f8b974..485329d3d2 100644 --- a/tools/nni_cmd/launcher.py +++ b/tools/nni_cmd/launcher.py @@ -203,6 +203,8 @@ def set_experiment(experiment_config, mode, port, config_file_name): request_data['description'] = experiment_config['description'] if experiment_config.get('multiPhase'): request_data['multiPhase'] = experiment_config.get('multiPhase') + if experiment_config.get('multiThread'): + request_data['multiThread'] = experiment_config.get('multiThread') if experiment_config.get('advisor'): request_data['advisor'] = experiment_config['advisor'] else: