From e470eef3e5771fd248851a3e173521b3d6f0f7f3 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Wed, 7 Aug 2019 09:06:00 +0800 Subject: [PATCH] PPO tuner for NAS, supports NNI's NAS interface (#1380) * ppo tuner --- docs/en_US/Tuner/BuiltinTuner.md | 30 + examples/trials/mnist-nas/config_ppo.yml | 20 + .../rest_server/restValidationSchemas.ts | 2 +- src/sdk/pynni/nni/constants.py | 4 +- src/sdk/pynni/nni/msg_dispatcher.py | 7 +- src/sdk/pynni/nni/ppo_tuner/__init__.py | 0 src/sdk/pynni/nni/ppo_tuner/distri.py | 198 ++++++ src/sdk/pynni/nni/ppo_tuner/model.py | 166 +++++ src/sdk/pynni/nni/ppo_tuner/policy.py | 219 +++++++ src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py | 589 ++++++++++++++++++ src/sdk/pynni/nni/ppo_tuner/requirements.txt | 3 + src/sdk/pynni/nni/ppo_tuner/util.py | 266 ++++++++ tools/nni_cmd/config_schema.py | 11 + tools/nni_cmd/constants.py | 3 +- 14 files changed, 1514 insertions(+), 4 deletions(-) create mode 100644 examples/trials/mnist-nas/config_ppo.yml create mode 100644 src/sdk/pynni/nni/ppo_tuner/__init__.py create mode 100644 src/sdk/pynni/nni/ppo_tuner/distri.py create mode 100644 src/sdk/pynni/nni/ppo_tuner/model.py create mode 100644 src/sdk/pynni/nni/ppo_tuner/policy.py create mode 100644 src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py create mode 100644 src/sdk/pynni/nni/ppo_tuner/requirements.txt create mode 100644 src/sdk/pynni/nni/ppo_tuner/util.py diff --git a/docs/en_US/Tuner/BuiltinTuner.md b/docs/en_US/Tuner/BuiltinTuner.md index 46dc8da60c..2bafaeb7ab 100644 --- a/docs/en_US/Tuner/BuiltinTuner.md +++ b/docs/en_US/Tuner/BuiltinTuner.md @@ -20,6 +20,7 @@ Currently we support the following algorithms: |[__Metis Tuner__](#MetisTuner)|Metis offers the following benefits when it comes to tuning parameters: While most tools only predict the optimal configuration, Metis gives you two outputs: (a) current prediction of optimal configuration, and (b) suggestion for the next trial. No more guesswork. While most tools assume training datasets do not have noisy data, Metis actually tells you if you need to re-sample a particular hyper-parameter. [Reference Paper](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/)| |[__BOHB__](#BOHB)|BOHB is a follow-up work of Hyperband. It targets the weakness of Hyperband that new configurations are generated randomly without leveraging finished trials. For the name BOHB, HB means Hyperband, BO means Bayesian Optimization. BOHB leverages finished trials by building multiple TPE models, a proportion of new configurations are generated through these models. [Reference Paper](https://arxiv.org/abs/1807.01774)| |[__GP Tuner__](#GPTuner)|Gaussian Process Tuner is a sequential model-based optimization (SMBO) approach with Gaussian Process as the surrogate. [Reference Paper](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf), [Github Repo](https://github.com/fmfn/BayesianOptimization)| +|[__PPO Tuner__](#PPOTuner)|PPO Tuner is an Reinforcement Learning tuner based on PPO algorithm. [Reference Paper](https://arxiv.org/abs/1707.06347)| ## Usage of Built-in Tuners @@ -409,3 +410,32 @@ tuner: selection_num_warm_up: 100000 selection_num_starting_points: 250 ``` + + + +![](https://placehold.it/15/1589F0/000000?text=+) `PPO Tuner` + +> Built-in Tuner Name: **PPOTuner** + +Note that the only acceptable type of search space is `mutable_layer`. `optional_input_size` can only be 0, 1, or [0, 1]. + +**Suggested scenario** + +PPOTuner is a Reinforcement Learning tuner based on PPO algorithm. When you are using NNI NAS interface in your trial code to do neural architecture search, PPOTuner is recommended. It has relatively high data efficiency but is suggested when you have large amount of computation resource. You could try it on very simple task, such as the [mnist-nas](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas) example. [Detailed Description](./PPOTuner.md) + +**Requirement of classArg** + +* **optimize_mode** (*'maximize' or 'minimize'*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics. +* **trials_per_update** (*int, optional, default = 20*) - The number of trials to be used for one update. This number is recommended to be larger than `trialConcurrency` and `trialConcurrency` be a aliquot devisor of `trials_per_update`. +* **epochs_per_update** (*int, optional, default = 4*) - The number of epochs for one update. +* **minibatch_size** (*int, optional, default = 4*) - Mini-batch size (i.e., number of trials for a mini-batch) for the update. Note that, trials_per_update should be divisible of minibatch_size. + +**Usage example** + +```yaml +# config.yml +tuner: + builtinTunerName: PPOTuner + classArgs: + optimize_mode: maximize +``` \ No newline at end of file diff --git a/examples/trials/mnist-nas/config_ppo.yml b/examples/trials/mnist-nas/config_ppo.yml new file mode 100644 index 0000000000..ba6a1cbd1e --- /dev/null +++ b/examples/trials/mnist-nas/config_ppo.yml @@ -0,0 +1,20 @@ +authorName: NNI-example +experimentName: example_mnist +trialConcurrency: 1 +maxExecDuration: 100h +maxTrialNum: 10000 +#choice: local, remote, pai +trainingServicePlatform: local +#choice: true, false +useAnnotation: true +tuner: + #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner + #SMAC (SMAC should be installed through nnictl) + #codeDir: ~/nni/nni/examples/tuners/random_nas_tuner + builtinTunerName: PPOTuner + classArgs: + optimize_mode: maximize +trial: + command: python3 mnist.py + codeDir: . + gpuNum: 0 diff --git a/src/nni_manager/rest_server/restValidationSchemas.ts b/src/nni_manager/rest_server/restValidationSchemas.ts index 19f88f11af..9328cdd696 100644 --- a/src/nni_manager/rest_server/restValidationSchemas.ts +++ b/src/nni_manager/rest_server/restValidationSchemas.ts @@ -167,7 +167,7 @@ export namespace ValidationSchemas { checkpointDir: joi.string().allow('') }), tuner: joi.object({ - builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'), + builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'), codeDir: joi.string(), classFileName: joi.string(), className: joi.string(), diff --git a/src/sdk/pynni/nni/constants.py b/src/sdk/pynni/nni/constants.py index ab726baa1b..5fc515da7b 100644 --- a/src/sdk/pynni/nni/constants.py +++ b/src/sdk/pynni/nni/constants.py @@ -30,7 +30,8 @@ 'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner', 'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor', 'MetisTuner': 'nni.metis_tuner.metis_tuner', - 'GPTuner': 'nni.gp_tuner.gp_tuner' + 'GPTuner': 'nni.gp_tuner.gp_tuner', + 'PPOTuner': 'nni.ppo_tuner.ppo_tuner' } ClassName = { @@ -44,6 +45,7 @@ 'NetworkMorphism':'NetworkMorphismTuner', 'MetisTuner':'MetisTuner', 'GPTuner':'GPTuner', + 'PPOTuner': 'PPOTuner', 'Medianstop': 'MedianstopAssessor', 'Curvefitting': 'CurvefittingAssessor' diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py index 0beda3f154..d6b0af2cb7 100644 --- a/src/sdk/pynni/nni/msg_dispatcher.py +++ b/src/sdk/pynni/nni/msg_dispatcher.py @@ -100,11 +100,16 @@ def handle_initialize(self, data): self.tuner.update_search_space(data) send(CommandType.Initialized, '') + def send_trial_callback(self, id, params): + """For tuner to issue trial config when the config is generated + """ + send(CommandType.NewTrialJob, _pack_parameter(id, params)) + def handle_request_trial_jobs(self, data): # data: number or trial jobs ids = [_create_parameter_id() for _ in range(data)] _logger.debug("requesting for generating params of {}".format(ids)) - params_list = self.tuner.generate_multiple_parameters(ids) + params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback) for i, _ in enumerate(params_list): send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i])) diff --git a/src/sdk/pynni/nni/ppo_tuner/__init__.py b/src/sdk/pynni/nni/ppo_tuner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/ppo_tuner/distri.py b/src/sdk/pynni/nni/ppo_tuner/distri.py new file mode 100644 index 0000000000..4666acc2da --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/distri.py @@ -0,0 +1,198 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +functions for sampling from hidden state +""" + +import tensorflow as tf + +from .util import fc + + +class Pd: + """ + A particular probability distribution + """ + def flatparam(self): + raise NotImplementedError + def mode(self): + raise NotImplementedError + def neglogp(self, x): + # Usually it's easier to define the negative logprob + raise NotImplementedError + def kl(self, other): + raise NotImplementedError + def entropy(self): + raise NotImplementedError + def sample(self): + raise NotImplementedError + def logp(self, x): + return - self.neglogp(x) + def get_shape(self): + return self.flatparam().shape + @property + def shape(self): + return self.get_shape() + def __getitem__(self, idx): + return self.__class__(self.flatparam()[idx]) + +class PdType: + """ + Parametrized family of probability distributions + """ + def pdclass(self): + raise NotImplementedError + def pdfromflat(self, flat, mask, nsteps, size, is_act_model): + return self.pdclass()(flat, mask, nsteps, size, is_act_model) + def pdfromlatent(self, latent_vector, init_scale, init_bias): + raise NotImplementedError + def param_shape(self): + raise NotImplementedError + def sample_shape(self): + raise NotImplementedError + def sample_dtype(self): + raise NotImplementedError + + def param_placeholder(self, prepend_shape, name=None): + return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name) + def sample_placeholder(self, prepend_shape, name=None): + return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name) + +class CategoricalPd(Pd): + """ + categorical prossibility distribution + """ + def __init__(self, logits, mask_npinf, nsteps, size, is_act_model): + self.logits = logits + self.mask_npinf = mask_npinf + self.nsteps = nsteps + self.size = size + self.is_act_model = is_act_model + def flatparam(self): + return self.logits + def mode(self): + return tf.argmax(self.logits, axis=-1) + + @property + def mean(self): + return tf.nn.softmax(self.logits) + def neglogp(self, x): + """ + return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) + Note: we can't use sparse_softmax_cross_entropy_with_logits because + the implementation does not allow second-order derivatives... + """ + if x.dtype in {tf.uint8, tf.int32, tf.int64}: + # one-hot encoding + x_shape_list = x.shape.as_list() + logits_shape_list = self.logits.get_shape().as_list()[:-1] + for xs, ls in zip(x_shape_list, logits_shape_list): + if xs is not None and ls is not None: + assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls) + + x = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) + else: + # already encoded + assert x.shape.as_list() == self.logits.shape.as_list() + + return tf.nn.softmax_cross_entropy_with_logits_v2( + logits=self.logits, + labels=x) + + def kl(self, other): + """kl""" + a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True) + a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True) + ea0 = tf.exp(a0) + ea1 = tf.exp(a1) + z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True) + z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1) + + def entropy(self): + """compute entropy""" + a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True) + ea0 = tf.exp(a0) + z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1) + + def sample(self): + """sample from logits""" + if not self.is_act_model: + re_res = tf.reshape(self.logits, [-1, self.nsteps, self.size]) + masked_res = tf.math.add(re_res, self.mask_npinf) + re_masked_res = tf.reshape(masked_res, [-1, self.size]) + + u = tf.random_uniform(tf.shape(re_masked_res), dtype=self.logits.dtype) + return tf.argmax(re_masked_res - tf.log(-tf.log(u)), axis=-1) + else: + u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype) + return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1) + + @classmethod + def fromflat(cls, flat): + return cls(flat) + +class CategoricalPdType(PdType): + """ + to create CategoricalPd + """ + def __init__(self, ncat, nsteps, np_mask, is_act_model): + self.ncat = ncat + self.nsteps = nsteps + self.np_mask = np_mask + self.is_act_model = is_act_model + def pdclass(self): + return CategoricalPd + + def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): + """add fc and create CategoricalPd""" + pdparam, mask, mask_npinf = _matching_fc(latent_vector, 'pi', self.ncat, self.nsteps, + init_scale=init_scale, init_bias=init_bias, + np_mask=self.np_mask, is_act_model=self.is_act_model) + return self.pdfromflat(pdparam, mask_npinf, self.nsteps, self.ncat, self.is_act_model), pdparam, mask, mask_npinf + + def param_shape(self): + return [self.ncat] + def sample_shape(self): + return [] + def sample_dtype(self): + return tf.int32 + +def _matching_fc(tensor, name, size, nsteps, init_scale, init_bias, np_mask, is_act_model): + """ + add fc op, and add mask op when not in action mode + """ + if tensor.shape[-1] == size: + assert False + return tensor + else: + mask = tf.get_variable("act_mask", dtype=tf.float32, initializer=np_mask[0], trainable=False) + mask_npinf = tf.get_variable("act_mask_npinf", dtype=tf.float32, initializer=np_mask[1], trainable=False) + res = fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias) + if not is_act_model: + re_res = tf.reshape(res, [-1, nsteps, size]) + masked_res = tf.math.multiply(re_res, mask) + re_masked_res = tf.reshape(masked_res, [-1, size]) + return re_masked_res, mask, mask_npinf + else: + return res, mask, mask_npinf diff --git a/src/sdk/pynni/nni/ppo_tuner/model.py b/src/sdk/pynni/nni/ppo_tuner/model.py new file mode 100644 index 0000000000..330f10369d --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/model.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +the main model of policy/value network +""" + +import tensorflow as tf + +from .util import initialize, get_session + +class Model: + """ + We use this object to : + __init__: + - Creates the step_model + - Creates the train_model + + train(): + - Make the training part (feedforward and retropropagation of gradients) + + save/load(): + - Save load the model + """ + def __init__(self, *, policy, nbatch_act, nbatch_train, + nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None, np_mask=None): + """ + init + """ + self.sess = sess = get_session() + + with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE): + # CREATE OUR TWO MODELS + # act_model that is used for sampling + act_model = policy(nbatch_act, 1, sess, np_mask=np_mask, is_act_model=True) + + # Train model for training + if microbatch_size is None: + train_model = policy(nbatch_train, nsteps, sess, np_mask=np_mask, is_act_model=False) + else: + train_model = policy(microbatch_size, nsteps, sess, np_mask=np_mask, is_act_model=False) + + # CREATE THE PLACEHOLDERS + self.A = A = train_model.pdtype.sample_placeholder([None]) + self.ADV = ADV = tf.placeholder(tf.float32, [None]) + self.R = R = tf.placeholder(tf.float32, [None]) + # Keep track of old actor + self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None]) + # Keep track of old critic + self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None]) + self.LR = LR = tf.placeholder(tf.float32, []) + # Cliprange + self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, []) + + neglogpac = train_model.pd.neglogp(A) + + # Calculate the entropy + # Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy. + entropy = tf.reduce_mean(train_model.pd.entropy()) + + # CALCULATE THE LOSS + # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss + + # Clip the value to reduce variability during Critic training + # Get the predicted value + vpred = train_model.vf + vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE) + # Unclipped value + vf_losses1 = tf.square(vpred - R) + # Clipped value + vf_losses2 = tf.square(vpredclipped - R) + + vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2)) + + # Calculate ratio (pi current policy / pi old policy) + ratio = tf.exp(OLDNEGLOGPAC - neglogpac) + + # Defining Loss = - J is equivalent to max J + pg_losses = -ADV * ratio + + pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE) + + # Final PG loss + pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2)) + approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC)) + clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE))) + + # Total loss + loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef + + # UPDATE THE PARAMETERS USING LOSS + # 1. Get the model parameters + params = tf.trainable_variables('ppo2_model') + # 2. Build our trainer + self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5) + # 3. Calculate the gradients + grads_and_var = self.trainer.compute_gradients(loss, params) + grads, var = zip(*grads_and_var) + + if max_grad_norm is not None: + # Clip the gradients (normalize) + grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm) + grads_and_var = list(zip(grads, var)) + # zip aggregate each gradient with parameters associated + # For instance zip(ABCD, xyza) => Ax, By, Cz, Da + + self.grads = grads + self.var = var + self._train_op = self.trainer.apply_gradients(grads_and_var) + self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac'] + self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac] + + + self.train_model = train_model + self.act_model = act_model + self.step = act_model.step + self.value = act_model.value + self.initial_state = act_model.initial_state + + initialize() + + def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None): + """ + train the model. + Here we calculate advantage A(s,a) = R + yV(s') - V(s) + Returns = R + yV(s') + """ + advs = returns - values + + # Normalize the advantages + advs = (advs - advs.mean()) / (advs.std() + 1e-8) + + td_map = { + self.train_model.X : obs, + self.A : actions, + self.ADV : advs, + self.R : returns, + self.LR : lr, + self.CLIPRANGE : cliprange, + self.OLDNEGLOGPAC : neglogpacs, + self.OLDVPRED : values + } + if states is not None: + td_map[self.train_model.S] = states + td_map[self.train_model.M] = masks + + return self.sess.run( + self.stats_list + [self._train_op], + td_map + )[:-1] diff --git a/src/sdk/pynni/nni/ppo_tuner/policy.py b/src/sdk/pynni/nni/ppo_tuner/policy.py new file mode 100644 index 0000000000..65e2db414e --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/policy.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +build policy/value network from model +""" + +import tensorflow as tf + +from .distri import CategoricalPdType +from .util import lstm_model, fc, observation_placeholder, adjust_shape + + +class PolicyWithValue: + """ + Encapsulates fields and methods for RL policy and value function estimation with shared parameters + """ + + def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, np_mask=None, is_act_model=False, **tensors): + """ + Parameters: + ---------- + env: RL environment + observations: tensorflow placeholder in which the observations will be fed + latent: latent state from which policy distribution parameters should be inferred + vf_latent: latent state from which value function should be inferred (if None, then latent is used) + sess: tensorflow session to run calculations in (if None, default session is used) + **tensors: tensorflow tensors for additional attributes such as state or mask + """ + + self.X = observations + self.state = tf.constant([]) + self.initial_state = None + self.__dict__.update(tensors) + + vf_latent = vf_latent if vf_latent is not None else latent + + vf_latent = tf.layers.flatten(vf_latent) + latent = tf.layers.flatten(latent) + + # Based on the action space, will select what probability distribution type + self.np_mask = np_mask + self.pdtype = CategoricalPdType(env.action_space.n, env.nsteps, np_mask, is_act_model) + + self.act_latent = latent + self.nh = env.action_space.n + + self.pd, self.pi, self.mask, self.mask_npinf = self.pdtype.pdfromlatent(latent, init_scale=0.01) + + # Take an action + self.action = self.pd.sample() + + # Calculate the neg log of our probability + self.neglogp = self.pd.neglogp(self.action) + self.sess = sess or tf.get_default_session() + + assert estimate_q is False + self.vf = fc(vf_latent, 'vf', 1) + self.vf = self.vf[:, 0] + + if is_act_model: + self._build_model_for_step() + + def _evaluate(self, variables, observation, **extra_feed): + sess = self.sess + feed_dict = {self.X: adjust_shape(self.X, observation)} + for inpt_name, data in extra_feed.items(): + if inpt_name in self.__dict__.keys(): + inpt = self.__dict__[inpt_name] + if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder': + feed_dict[inpt] = adjust_shape(inpt, data) + + return sess.run(variables, feed_dict) + + def _build_model_for_step(self): + # multiply with weight and apply mask on self.act_latent to generate + self.act_step = step = tf.placeholder(shape=(), dtype=tf.int64, name='act_step') + with tf.variable_scope('pi', reuse=tf.AUTO_REUSE): + from .util import ortho_init + nin = self.act_latent.get_shape()[1].value + w = tf.get_variable("w", [nin, self.nh], initializer=ortho_init(0.01)) + b = tf.get_variable("b", [self.nh], initializer=tf.constant_initializer(0.0)) + logits = tf.matmul(self.act_latent, w)+b + piece = tf.slice(self.mask, [step, 0], [1, self.nh]) + re_piece = tf.reshape(piece, [-1]) + masked_logits = tf.math.multiply(logits, re_piece) + + npinf_piece = tf.slice(self.mask_npinf, [step, 0], [1, self.nh]) + re_npinf_piece = tf.reshape(npinf_piece, [-1]) + + def sample(logits, mask_npinf): + new_logits = tf.math.add(logits, mask_npinf) + u = tf.random_uniform(tf.shape(new_logits), dtype=logits.dtype) + return tf.argmax(new_logits - tf.log(-tf.log(u)), axis=-1) + + def neglogp(logits, x): + # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) + # Note: we can't use sparse_softmax_cross_entropy_with_logits because + # the implementation does not allow second-order derivatives... + if x.dtype in {tf.uint8, tf.int32, tf.int64}: + # one-hot encoding + x_shape_list = x.shape.as_list() + logits_shape_list = logits.get_shape().as_list()[:-1] + for xs, ls in zip(x_shape_list, logits_shape_list): + if xs is not None and ls is not None: + assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls) + + x = tf.one_hot(x, logits.get_shape().as_list()[-1]) + else: + # already encoded + assert x.shape.as_list() == logits.shape.as_list() + + return tf.nn.softmax_cross_entropy_with_logits_v2( + logits=logits, + labels=x) + + self.act_action = sample(masked_logits, re_npinf_piece) + self.act_neglogp = neglogp(masked_logits, self.act_action) + + + def step(self, step, observation, **extra_feed): + """ + Compute next action(s) given the observation(s) + + Parameters: + ---------- + observation: observation data (either single or a batch) + **extra_feed: additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__) + + Returns: + ------- + (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple + """ + extra_feed['act_step'] = step + a, v, state, neglogp = self._evaluate([self.act_action, self.vf, self.state, self.act_neglogp], observation, **extra_feed) + if state.size == 0: + state = None + return a, v, state, neglogp + + def value(self, ob, *args, **kwargs): + """ + Compute value estimate(s) given the observation(s) + + Parameters: + ---------- + observation: observation data (either single or a batch) + **extra_feed: additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__) + + Returns: + ------- + value estimate + """ + return self._evaluate(self.vf, ob, *args, **kwargs) + + +def build_lstm_policy(model_config, value_network=None, estimate_q=False, **policy_kwargs): + """ + build lstm policy and value network, they share the same lstm network. + the parameters all use their default values. + """ + policy_network = lstm_model(**policy_kwargs) + + def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None, np_mask=None, is_act_model=False): + ob_space = model_config.observation_space + + X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch) + + extra_tensors = {} + + # encode_observation is not necessary anymore as we use embedding_lookup + encoded_x = X + + with tf.variable_scope('pi', reuse=tf.AUTO_REUSE): + policy_latent = policy_network(encoded_x, 1, model_config.observation_space.n) + if isinstance(policy_latent, tuple): + policy_latent, recurrent_tensors = policy_latent + + if recurrent_tensors is not None: + # recurrent architecture, need a few more steps + nenv = nbatch // nsteps + assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps) + policy_latent, recurrent_tensors = policy_network(encoded_x, nenv, model_config.observation_space.n) + extra_tensors.update(recurrent_tensors) + + _v_net = value_network + + assert _v_net is None or _v_net == 'shared' + vf_latent = policy_latent + + policy = PolicyWithValue( + env=model_config, + observations=X, + latent=policy_latent, + vf_latent=vf_latent, + sess=sess, + estimate_q=estimate_q, + np_mask=np_mask, + is_act_model=is_act_model, + **extra_tensors + ) + return policy + + return policy_fn diff --git a/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py b/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py new file mode 100644 index 0000000000..77b232b442 --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py @@ -0,0 +1,589 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +ppo_tuner.py including: + class PPOTuner +""" + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" +import copy +import logging +import numpy as np +import json_tricks +from gym import spaces + +import nni +from nni.tuner import Tuner +from nni.utils import OptimizeMode, extract_scalar_reward + +from .model import Model +from .util import set_global_seeds +from .policy import build_lstm_policy + + +logger = logging.getLogger('ppo_tuner_AutoML') + +def constfn(val): + """wrap as function""" + def f(_): + return val + return f + + +class ModelConfig: + """ + Configurations of the PPO model + """ + def __init__(self): + self.observation_space = None + self.action_space = None + self.num_envs = 0 + self.nsteps = 0 + + self.ent_coef = 0.0 + self.lr = 3e-4 + self.vf_coef = 0.5 + self.max_grad_norm = 0.5 + self.gamma = 0.99 + self.lam = 0.95 + self.cliprange = 0.2 + self.embedding_size = None # the embedding is for each action + + self.noptepochs = 4 # number of training epochs per update + self.total_timesteps = 5000 # number of timesteps (i.e. number of actions taken in the environment) + self.nminibatches = 4 # number of training minibatches per update. For recurrent policies, + # should be smaller or equal than number of environments run in parallel. + +class TrialsInfo: + """ + Informations of each trial from one model inference + """ + def __init__(self, obs, actions, values, neglogpacs, dones, last_value, inf_batch_size): + self.iter = 0 + self.obs = obs + self.actions = actions + self.values = values + self.neglogpacs = neglogpacs + self.dones = dones + self.last_value = last_value + + self.rewards = None + self.returns = None + + self.inf_batch_size = inf_batch_size + #self.states = None + + def get_next(self): + """ + get actions of the next trial + """ + if self.iter >= self.inf_batch_size: + return None, None + actions = [] + for step in self.actions: + actions.append(step[self.iter]) + self.iter += 1 + return self.iter - 1, actions + + def update_rewards(self, rewards, returns): + """ + after the trial is finished, reward and return of this trial is updated + """ + self.rewards = rewards + self.returns = returns + + def convert_shape(self): + """ + convert shape + """ + def sf01(arr): + """ + swap and then flatten axes 0 and 1 + """ + s = arr.shape + return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:]) + self.obs = sf01(self.obs) + self.returns = sf01(self.returns) + self.dones = sf01(self.dones) + self.actions = sf01(self.actions) + self.values = sf01(self.values) + self.neglogpacs = sf01(self.neglogpacs) + + +class PPOModel: + """ + PPO Model + """ + def __init__(self, model_config, mask): + self.model_config = model_config + self.states = None # initial state of lstm in policy/value network + self.nupdates = None # the number of func train is invoked, used to tune lr and cliprange + self.cur_update = 1 # record the current update + self.np_mask = mask # record the mask of each action within one trial + + set_global_seeds(None) + assert isinstance(self.model_config.lr, float) + self.lr = constfn(self.model_config.lr) + assert isinstance(self.model_config.cliprange, float) + self.cliprange = constfn(self.model_config.cliprange) + + # build lstm policy network, value share the same network + policy = build_lstm_policy(model_config) + + # Get the nb of env + nenvs = model_config.num_envs + + # Calculate the batch_size + self.nbatch = nbatch = nenvs * model_config.nsteps + nbatch_train = nbatch // model_config.nminibatches + self.nupdates = self.model_config.total_timesteps//self.nbatch + + # Instantiate the model object (that creates act_model and train_model) + self.model = Model(policy=policy, nbatch_act=nenvs, nbatch_train=nbatch_train, + nsteps=model_config.nsteps, ent_coef=model_config.ent_coef, vf_coef=model_config.vf_coef, + max_grad_norm=model_config.max_grad_norm, np_mask=self.np_mask) + + self.states = self.model.initial_state + + logger.info('=== finished PPOModel initialization') + + def inference(self, num): + """ + generate actions along with related info from policy network. + observation is the action of the last step. + + Parameters: + ---------- + num: the number of trials to generate + """ + # Here, we init the lists that will contain the mb of experiences + mb_obs, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], [] + # initial observation + # use the (n+1)th embedding to represent the first step action + first_step_ob = self.model_config.action_space.n + obs = [first_step_ob for _ in range(num)] + dones = [True for _ in range(num)] + states = self.states + # For n in range number of steps + for cur_step in range(self.model_config.nsteps): + # Given observations, get action value and neglopacs + # We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init + actions, values, states, neglogpacs = self.model.step(cur_step, obs, S=states, M=dones) + mb_obs.append(obs.copy()) + mb_actions.append(actions) + mb_values.append(values) + mb_neglogpacs.append(neglogpacs) + mb_dones.append(dones) + + # Take actions in env and look the results + # Infos contains a ton of useful informations + obs[:] = actions + if cur_step == self.model_config.nsteps - 1: + dones = [True for _ in range(num)] + else: + dones = [False for _ in range(num)] + + #batch of steps to batch of rollouts + np_obs = np.asarray(obs) + mb_obs = np.asarray(mb_obs, dtype=np_obs.dtype) + mb_actions = np.asarray(mb_actions) + mb_values = np.asarray(mb_values, dtype=np.float32) + mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32) + mb_dones = np.asarray(mb_dones, dtype=np.bool) + last_values = self.model.value(np_obs, S=states, M=dones) + + return mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values + + def compute_rewards(self, trials_info, trials_result): + """ + compute the rewards of the trials in trials_info based on trials_result, + and update the rewards in trials_info + + Parameters: + ---------- + trials_info: info of the generated trials + trials_result: final results (e.g., acc) of the generated trials + """ + mb_rewards = np.asarray([trials_result for _ in trials_info.actions], dtype=np.float32) + # discount/bootstrap off value fn + mb_returns = np.zeros_like(mb_rewards) + mb_advs = np.zeros_like(mb_rewards) + lastgaelam = 0 + last_dones = np.asarray([True for _ in trials_result], dtype=np.bool) # ugly + for t in reversed(range(self.model_config.nsteps)): + if t == self.model_config.nsteps - 1: + nextnonterminal = 1.0 - last_dones + nextvalues = trials_info.last_value + else: + nextnonterminal = 1.0 - trials_info.dones[t+1] + nextvalues = trials_info.values[t+1] + delta = mb_rewards[t] + self.model_config.gamma * nextvalues * nextnonterminal - trials_info.values[t] + mb_advs[t] = lastgaelam = delta + self.model_config.gamma * self.model_config.lam * nextnonterminal * lastgaelam + mb_returns = mb_advs + trials_info.values + + trials_info.update_rewards(mb_rewards, mb_returns) + trials_info.convert_shape() + + def train(self, trials_info, nenvs): + """ + train the policy/value network using trials_info + + Parameters: + ---------- + trials_info: complete info of the generated trials from the previous inference + nenvs: the batch size of the (previous) inference + """ + if self.cur_update <= self.nupdates: + frac = 1.0 - (self.cur_update - 1.0) / self.nupdates + else: + logger.warning('current update (self.cur_update) %d has exceeded total updates (self.nupdates) %d', + self.cur_update, self.nupdates) + frac = 1.0 - (self.nupdates - 1.0) / self.nupdates + lrnow = self.lr(frac) + cliprangenow = self.cliprange(frac) + self.cur_update += 1 + + states = self.states + + assert states is not None # recurrent version + assert nenvs % self.model_config.nminibatches == 0 + envsperbatch = nenvs // self.model_config.nminibatches + envinds = np.arange(nenvs) + flatinds = np.arange(nenvs * self.model_config.nsteps).reshape(nenvs, self.model_config.nsteps) + for _ in range(self.model_config.noptepochs): + np.random.shuffle(envinds) + for start in range(0, nenvs, envsperbatch): + end = start + envsperbatch + mbenvinds = envinds[start:end] + mbflatinds = flatinds[mbenvinds].ravel() + slices = (arr[mbflatinds] for arr in (trials_info.obs, trials_info.returns, trials_info.dones, + trials_info.actions, trials_info.values, trials_info.neglogpacs)) + mbstates = states[mbenvinds] + self.model.train(lrnow, cliprangenow, *slices, mbstates) + + +class PPOTuner(Tuner): + """ + PPOTuner + """ + + def __init__(self, optimize_mode, trials_per_update=20, epochs_per_update=4, minibatch_size=4): + """ + initialization, PPO model is not initialized here as search space is not received yet. + + Parameters: + ---------- + optimize_mode: maximize or minimize + trials_per_update: number of trials to have for each model update + epochs_per_update: number of epochs to run for each model update + minibatch_size: minibatch size (number of trials) for the update + """ + self.optimize_mode = OptimizeMode(optimize_mode) + self.model_config = ModelConfig() + self.model = None + self.search_space = None + self.running_trials = {} # key: parameter_id, value: actions/states/etc. + self.inf_batch_size = trials_per_update # number of trials to generate in one inference + self.first_inf = True # indicate whether it is the first time to inference new trials + self.trials_result = [None for _ in range(self.inf_batch_size)] # results of finished trials + + self.credit = 0 # record the unsatisfied trial requests + self.param_ids = [] + self.finished_trials = 0 + self.chosen_arch_template = {} + + self.actions_spaces = None + self.actions_to_config = None + self.full_act_space = None + self.trials_info = None + + self.all_trials = {} # used to dedup the same trial, key: config, value: final result + + self.model_config.num_envs = self.inf_batch_size + self.model_config.noptepochs = epochs_per_update + self.model_config.nminibatches = minibatch_size + + self.send_trial_callback = None + logger.info('=== finished PPOTuner initialization') + + def _process_one_nas_space(self, block_name, block_space): + """ + process nas space to determine observation space and action space + + Parameters: + ---------- + block_name: the name of the mutable block + block_space: search space of this mutable block + + Returns: + ---------- + actions_spaces: list of the space of each action + actions_to_config: the mapping from action to generated configuration + """ + actions_spaces = [] + actions_to_config = [] + + block_arch_temp = {} + for l_name, layer in block_space.items(): + chosen_layer_temp = {} + + if len(layer['layer_choice']) > 1: + actions_spaces.append(layer['layer_choice']) + actions_to_config.append((block_name, l_name, 'chosen_layer')) + chosen_layer_temp['chosen_layer'] = None + else: + assert len(layer['layer_choice']) == 1 + chosen_layer_temp['chosen_layer'] = layer['layer_choice'][0] + + if layer['optional_input_size'] not in [0, 1, [0, 1]]: + raise ValueError('Optional_input_size can only be 0, 1, or [0, 1], but the pecified one is %s' + % (layer['optional_input_size'])) + if isinstance(layer['optional_input_size'], list): + actions_spaces.append(["None", *layer['optional_inputs']]) + actions_to_config.append((block_name, l_name, 'chosen_inputs')) + chosen_layer_temp['chosen_inputs'] = None + elif layer['optional_input_size'] == 1: + actions_spaces.append(layer['optional_inputs']) + actions_to_config.append((block_name, l_name, 'chosen_inputs')) + chosen_layer_temp['chosen_inputs'] = None + elif layer['optional_input_size'] == 0: + chosen_layer_temp['chosen_inputs'] = [] + else: + raise ValueError('invalid type and value of optional_input_size') + + block_arch_temp[l_name] = chosen_layer_temp + + self.chosen_arch_template[block_name] = block_arch_temp + + return actions_spaces, actions_to_config + + def _process_nas_space(self, search_space): + """ + process nas search space to get action/observation space + """ + actions_spaces = [] + actions_to_config = [] + for b_name, block in search_space.items(): + if block['_type'] != 'mutable_layer': + raise ValueError('PPOTuner only accept mutable_layer type in search space, but the current one is %s'%(block['_type'])) + block = block['_value'] + act, act_map = self._process_one_nas_space(b_name, block) + actions_spaces.extend(act) + actions_to_config.extend(act_map) + + # calculate observation space + dedup = {} + for step in actions_spaces: + for action in step: + dedup[action] = 1 + full_act_space = [act for act, _ in dedup.items()] + assert len(full_act_space) == len(dedup) + observation_space = len(full_act_space) + + nsteps = len(actions_spaces) + + return actions_spaces, actions_to_config, full_act_space, observation_space, nsteps + + def _generate_action_mask(self): + """ + different step could have different action space. to deal with this case, we merge all the + possible actions into one action space, and use mask to indicate available actions for each step + """ + two_masks = [] + + mask = [] + for acts in self.actions_spaces: + one_mask = [0 for _ in range(len(self.full_act_space))] + for act in acts: + idx = self.full_act_space.index(act) + one_mask[idx] = 1 + mask.append(one_mask) + two_masks.append(mask) + + mask = [] + for acts in self.actions_spaces: + one_mask = [-np.inf for _ in range(len(self.full_act_space))] + for act in acts: + idx = self.full_act_space.index(act) + one_mask[idx] = 0 + mask.append(one_mask) + two_masks.append(mask) + + return np.asarray(two_masks, dtype=np.float32) + + def update_search_space(self, search_space): + """ + get search space, currently the space only includes that for NAS + + Parameters: + ---------- + search_space: search space for NAS + + Returns: + ------- + no return + """ + logger.info('=== update search space %s', search_space) + assert self.search_space is None + self.search_space = search_space + + assert self.model_config.observation_space is None + assert self.model_config.action_space is None + + self.actions_spaces, self.actions_to_config, self.full_act_space, obs_space, nsteps = self._process_nas_space(search_space) + + self.model_config.observation_space = spaces.Discrete(obs_space) + self.model_config.action_space = spaces.Discrete(obs_space) + self.model_config.nsteps = nsteps + + # generate mask in numpy + mask = self._generate_action_mask() + + assert self.model is None + self.model = PPOModel(self.model_config, mask) + + def _actions_to_config(self, actions): + """ + given actions, to generate the corresponding trial configuration + """ + chosen_arch = copy.deepcopy(self.chosen_arch_template) + for cnt, act in enumerate(actions): + act_name = self.full_act_space[act] + (block_name, layer_name, key) = self.actions_to_config[cnt] + if key == 'chosen_inputs': + if act_name == 'None': + chosen_arch[block_name][layer_name][key] = [] + else: + chosen_arch[block_name][layer_name][key] = [act_name] + elif key == 'chosen_layer': + chosen_arch[block_name][layer_name][key] = act_name + else: + raise ValueError('unrecognized key: {0}'.format(key)) + return chosen_arch + + def generate_multiple_parameters(self, parameter_id_list, **kwargs): + """ + Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects. + """ + result = [] + self.send_trial_callback = kwargs['st_callback'] + for parameter_id in parameter_id_list: + had_exception = False + try: + logger.debug("generating param for %s", parameter_id) + res = self.generate_parameters(parameter_id, **kwargs) + except nni.NoMoreTrialError: + had_exception = True + if not had_exception: + result.append(res) + return result + + def generate_parameters(self, parameter_id, **kwargs): + """ + generate parameters, if no trial configration for now, self.credit plus 1 to send the config later + """ + if self.first_inf: + self.trials_result = [None for _ in range(self.inf_batch_size)] + mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size) + self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs, + mb_dones, last_values, self.inf_batch_size) + self.first_inf = False + + trial_info_idx, actions = self.trials_info.get_next() + if trial_info_idx is None: + self.credit += 1 + self.param_ids.append(parameter_id) + raise nni.NoMoreTrialError('no more parameters now.') + + self.running_trials[parameter_id] = trial_info_idx + new_config = self._actions_to_config(actions) + return new_config + + def _next_round_inference(self): + """ + """ + self.finished_trials = 0 + self.model.compute_rewards(self.trials_info, self.trials_result) + self.model.train(self.trials_info, self.inf_batch_size) + self.running_trials = {} + # generate new trials + self.trials_result = [None for _ in range(self.inf_batch_size)] + mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size) + self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs, + mb_dones, last_values, self.inf_batch_size) + # check credit and submit new trials + for _ in range(self.credit): + trial_info_idx, actions = self.trials_info.get_next() + if trial_info_idx is None: + logger.warning('No enough trial config, trials_per_update is suggested to be larger than trialConcurrency') + break + assert self.param_ids + param_id = self.param_ids.pop() + self.running_trials[param_id] = trial_info_idx + new_config = self._actions_to_config(actions) + self.send_trial_callback(param_id, new_config) + self.credit -= 1 + + def receive_trial_result(self, parameter_id, parameters, value, **kwargs): + """ + receive trial's result. if the number of finished trials equals self.inf_batch_size, start the next update to + train the model + """ + trial_info_idx = self.running_trials.pop(parameter_id, None) + assert trial_info_idx is not None + + value = extract_scalar_reward(value) + if self.optimize_mode == OptimizeMode.Minimize: + value = -value + + self.trials_result[trial_info_idx] = value + self.finished_trials += 1 + + if self.finished_trials == self.inf_batch_size: + self._next_round_inference() + + def trial_end(self, parameter_id, success, **kwargs): + """ + to deal with trial failure + """ + if not success: + if parameter_id not in self.running_trials: + logger.warning('The trial is failed, but self.running_trial does not have this trial') + return + trial_info_idx = self.running_trials.pop(parameter_id, None) + assert trial_info_idx is not None + # use mean of finished trials as the result of this failed trial + values = [val for val in self.trials_result if val is not None] + logger.warning('zql values: {0}'.format(values)) + self.trials_result[trial_info_idx] = (sum(values) / len(values)) if len(values) > 0 else 0 + self.finished_trials += 1 + if self.finished_trials == self.inf_batch_size: + self._next_round_inference() + + def import_data(self, data): + """ + Import additional data for tuning + + Parameters + ---------- + data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' + """ + logger.warning('PPOTuner cannot leverage imported data.') diff --git a/src/sdk/pynni/nni/ppo_tuner/requirements.txt b/src/sdk/pynni/nni/ppo_tuner/requirements.txt new file mode 100644 index 0000000000..138951469b --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/requirements.txt @@ -0,0 +1,3 @@ +enum34 +gym +tensorflow \ No newline at end of file diff --git a/src/sdk/pynni/nni/ppo_tuner/util.py b/src/sdk/pynni/nni/ppo_tuner/util.py new file mode 100644 index 0000000000..ac958e54de --- /dev/null +++ b/src/sdk/pynni/nni/ppo_tuner/util.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +util functions +""" + +import os +import random +import multiprocessing +import numpy as np +import tensorflow as tf +from gym.spaces import Discrete, Box, MultiDiscrete + +def set_global_seeds(i): + """set global seeds""" + rank = 0 + myseed = i + 1000 * rank if i is not None else None + tf.set_random_seed(myseed) + np.random.seed(myseed) + random.seed(myseed) + +def batch_to_seq(h, nbatch, nsteps, flat=False): + """convert from batch to sequence""" + if flat: + h = tf.reshape(h, [nbatch, nsteps]) + else: + h = tf.reshape(h, [nbatch, nsteps, -1]) + return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)] + +def seq_to_batch(h, flat=False): + """convert from sequence to batch""" + shape = h[0].get_shape().as_list() + if not flat: + assert len(shape) > 1 + nh = h[0].get_shape()[-1].value + return tf.reshape(tf.concat(axis=1, values=h), [-1, nh]) + else: + return tf.reshape(tf.stack(values=h, axis=1), [-1]) + +def lstm(xs, ms, s, scope, nh, init_scale=1.0): + """lstm cell""" + nbatch, nin = [v.value for v in xs[0].get_shape()] + with tf.variable_scope(scope): + wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) + wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) + b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) + + c, h = tf.split(axis=1, num_or_size_splits=2, value=s) + for idx, (x, m) in enumerate(zip(xs, ms)): + c = c*(1-m) + h = h*(1-m) + z = tf.matmul(x, wx) + tf.matmul(h, wh) + b + i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) + i = tf.nn.sigmoid(i) + f = tf.nn.sigmoid(f) + o = tf.nn.sigmoid(o) + u = tf.tanh(u) + c = f*c + i*u + h = o*tf.tanh(c) + xs[idx] = h + s = tf.concat(axis=1, values=[c, h]) + return xs, s + +def lstm_model(nlstm=128, layer_norm=False): + """ + Builds LSTM (Long-Short Term Memory) network to be used in a policy. + Note that the resulting function returns not only the output of the LSTM + (i.e. hidden state of lstm for each step in the sequence), but also a dictionary + with auxiliary tensors to be set as policy attributes. + + Specifically, + S is a placeholder to feed current state (LSTM state has to be managed outside policy) + M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too) + initial_state is a numpy array containing initial lstm state (usually zeros) + state is the output LSTM state (to be fed into S at the next call) + + + An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example + + Parameters: + ---------- + nlstm: int LSTM hidden state size + layer_norm: bool if True, layer-normalized version of LSTM is used + + Returns: + ------- + function that builds LSTM with a given input tensor / placeholder + """ + + def network_fn(X, nenv=1, obs_size=-1): + with tf.variable_scope("emb", reuse=tf.AUTO_REUSE): + w_emb = tf.get_variable("w_emb", [obs_size+1, 32]) + X = tf.nn.embedding_lookup(w_emb, X) + + nbatch = X.shape[0] + nsteps = nbatch // nenv + + h = tf.layers.flatten(X) + + M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) + S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states + + xs = batch_to_seq(h, nenv, nsteps) + ms = batch_to_seq(M, nenv, nsteps) + + assert not layer_norm + h5, snew = lstm(xs, ms, S, scope='lstm', nh=nlstm) + + h = seq_to_batch(h5) + initial_state = np.zeros(S.shape.as_list(), dtype=float) + + return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state} + + return network_fn + +def ortho_init(scale=1.0): + """init approach""" + def _ortho_init(shape, dtype, partition_info=None): + #lasagne ortho init for tf + shape = tuple(shape) + if len(shape) == 2: + flat_shape = shape + elif len(shape) == 4: # assumes NHWC + flat_shape = (np.prod(shape[:-1]), shape[-1]) + else: + raise NotImplementedError + a = np.random.normal(0.0, 1.0, flat_shape) + u, _, v = np.linalg.svd(a, full_matrices=False) + q = u if u.shape == flat_shape else v # pick the one with the correct shape + q = q.reshape(shape) + return (scale * q[:shape[0], :shape[1]]).astype(np.float32) + return _ortho_init + +def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): + """fully connected op""" + with tf.variable_scope(scope): + nin = x.get_shape()[1].value + w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) + b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) + return tf.matmul(x, w)+b + +def _check_shape(placeholder_shape, data_shape): + """ + check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension) + """ + + return True + +# ================================================================ +# Shape adjustment for feeding into tf placeholders +# ================================================================ +def adjust_shape(placeholder, data): + """ + adjust shape of the data to the shape of the placeholder if possible. + If shape is incompatible, AssertionError is thrown + + Parameters: + placeholder: tensorflow input placeholder + data: input data to be (potentially) reshaped to be fed into placeholder + + Returns: + reshaped data + """ + if not isinstance(data, np.ndarray) and not isinstance(data, list): + return data + if isinstance(data, list): + data = np.array(data) + + placeholder_shape = [x or -1 for x in placeholder.shape.as_list()] + + assert _check_shape(placeholder_shape, data.shape), \ + 'Shape of data {} is not compatible with shape of the placeholder {}'.format(data.shape, placeholder_shape) + + return np.reshape(data, placeholder_shape) + +# ================================================================ +# Global session +# ================================================================ + +def get_session(config=None): + """Get default session or create one with a given config""" + sess = tf.get_default_session() + if sess is None: + sess = make_session(config=config, make_default=True) + return sess + +def make_session(config=None, num_cpu=None, make_default=False, graph=None): + """Returns a session that will use CPU's only""" + if num_cpu is None: + num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count())) + if config is None: + config = tf.ConfigProto( + allow_soft_placement=True, + inter_op_parallelism_threads=num_cpu, + intra_op_parallelism_threads=num_cpu) + config.gpu_options.allow_growth = True + + if make_default: + return tf.InteractiveSession(config=config, graph=graph) + else: + return tf.Session(config=config, graph=graph) + +ALREADY_INITIALIZED = set() + +def initialize(): + """Initialize all the uninitialized variables in the global scope.""" + new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED + get_session().run(tf.variables_initializer(new_variables)) + + ALREADY_INITIALIZED.update(new_variables) + +def observation_placeholder(ob_space, batch_size=None, name='Ob'): + """ + Create placeholder to feed observations into of the size appropriate to the observation space + + Parameters: + ---------- + ob_space: gym.Space observation space + batch_size: int size of the batch to be fed into input. Can be left None in most cases. + name: str name of the placeholder + + Returns: + ------- + tensorflow placeholder tensor + """ + + assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \ + 'Can only deal with Discrete and Box observation spaces for now' + + dtype = ob_space.dtype + if dtype == np.int8: + dtype = np.uint8 + + return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) + +def explained_variance(ypred, y): + """ + Computes fraction of variance that ypred explains about y. + Returns 1 - Var[y-ypred] / Var[y] + + interpretation: + ev=0 => might as well have predicted zero + ev=1 => perfect prediction + ev<0 => worse than just predicting zero + + """ + assert y.ndim == 1 and ypred.ndim == 1 + vary = np.var(y) + return np.nan if vary == 0 else 1 - np.var(y-ypred)/vary diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index d5dbe49dcd..4f4c9dce08 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -142,6 +142,17 @@ def setPathCheck(key): Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool), Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), }, + 'PPOTuner': { + 'builtinTunerName': 'PPOTuner', + 'classArgs': { + 'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'), + Optional('trials_per_update'): setNumberRange('trials_per_update', int, 0, 99999), + Optional('epochs_per_update'): setNumberRange('epochs_per_update', int, 0, 99999), + Optional('minibatch_size'): setNumberRange('minibatch_size', int, 0, 99999), + }, + Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool), + Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), + }, 'customized': { 'codeDir': setPathCheck('codeDir'), 'classFileName': setType('classFileName', str), diff --git a/tools/nni_cmd/constants.py b/tools/nni_cmd/constants.py index 04a3dbbaff..d22a509c46 100644 --- a/tools/nni_cmd/constants.py +++ b/tools/nni_cmd/constants.py @@ -80,7 +80,8 @@ PACKAGE_REQUIREMENTS = { 'SMAC': 'smac_tuner', - 'BOHB': 'bohb_advisor' + 'BOHB': 'bohb_advisor', + 'PPOTuner': 'ppo_tuner' } TUNERS_SUPPORTING_IMPORT_DATA = {