PPO tuner for NAS, supports NNI's NAS interface (#1380)
* ppo tuner
QuanluZhang authored Aug 7, 2019
1 parent 3a6d137 commit e470eef
Showing 14 changed files with 1,514 additions and 4 deletions.
30 changes: 30 additions & 0 deletions docs/en_US/Tuner/BuiltinTuner.md
@@ -20,6 +20,7 @@ Currently we support the following algorithms:
|[__Metis Tuner__](#MetisTuner)|Metis offers the following benefits when it comes to tuning parameters: While most tools only predict the optimal configuration, Metis gives you two outputs: (a) current prediction of optimal configuration, and (b) suggestion for the next trial. No more guesswork. While most tools assume training datasets do not have noisy data, Metis actually tells you if you need to re-sample a particular hyper-parameter. [Reference Paper](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/)|
|[__BOHB__](#BOHB)|BOHB is a follow-up work of Hyperband. It targets the weakness of Hyperband that new configurations are generated randomly without leveraging finished trials. For the name BOHB, HB means Hyperband, BO means Bayesian Optimization. BOHB leverages finished trials by building multiple TPE models, a proportion of new configurations are generated through these models. [Reference Paper](https://arxiv.org/abs/1807.01774)|
|[__GP Tuner__](#GPTuner)|Gaussian Process Tuner is a sequential model-based optimization (SMBO) approach with Gaussian Process as the surrogate. [Reference Paper](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf), [Github Repo](https://github.com/fmfn/BayesianOptimization)|
|[__PPO Tuner__](#PPOTuner)|PPO Tuner is a Reinforcement Learning tuner based on the PPO algorithm. [Reference Paper](https://arxiv.org/abs/1707.06347)|

## Usage of Built-in Tuners

@@ -409,3 +410,32 @@ tuner:
    selection_num_warm_up: 100000
    selection_num_starting_points: 250
```

<a name="PPOTuner"></a>

![](https://placehold.it/15/1589F0/000000?text=+) `PPO Tuner`

> Built-in Tuner Name: **PPOTuner**

Note that the only acceptable type of search space is `mutable_layer`. `optional_input_size` can only be 0, 1, or [0, 1].
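
As an illustration, a `mutable_layer` entry in the search space looks roughly like the sketch below, written here as a Python dict. The layer and input names are invented for this example, and the exact schema is the one defined by the NNI NAS interface documentation:

```python
# Illustrative sketch only: names below are made up; consult the NNI NAS
# interface docs for the authoritative search space schema.
search_space = {
    "mutable_layers": {
        "_type": "mutable_layer",
        "_value": {
            "mutable_layer_1": {
                "layer_choice": ["conv2d", "max_pool", "avg_pool"],
                "optional_inputs": ["image", "prev_out"],
                "optional_input_size": 1,  # PPOTuner accepts only 0, 1, or [0, 1]
            }
        }
    }
}
```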

**Suggested scenario**

PPOTuner is a Reinforcement Learning tuner based on the PPO algorithm. It is recommended when you are using the NNI NAS interface in your trial code to do neural architecture search. It has relatively high data efficiency, but is suggested only when you have a large amount of computational resources. You could try it first on a very simple task, such as the [mnist-nas](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas) example. [Detailed Description](./PPOTuner.md)

**Requirement of classArg**

* **optimize_mode** (*'maximize' or 'minimize'*) - If 'maximize', the tuner will try to maximize metrics; if 'minimize', the tuner will try to minimize metrics.
* **trials_per_update** (*int, optional, default = 20*) - The number of trials used for one update. This number is recommended to be larger than `trialConcurrency`, and `trialConcurrency` should be an exact divisor of `trials_per_update`.
* **epochs_per_update** (*int, optional, default = 4*) - The number of epochs for one update.
* **minibatch_size** (*int, optional, default = 4*) - Mini-batch size (i.e., the number of trials in a mini-batch) for the update. Note that `trials_per_update` must be divisible by `minibatch_size`; for example, the defaults of 20 and 4 give 5 mini-batches per epoch.

**Usage example**

```yaml
# config.yml
tuner:
  builtinTunerName: PPOTuner
  classArgs:
    optimize_mode: maximize
```
20 changes: 20 additions & 0 deletions examples/trials/mnist-nas/config_ppo.yml
@@ -0,0 +1,20 @@
authorName: NNI-example
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 100h
maxTrialNum: 10000
#choice: local, remote, pai
trainingServicePlatform: local
#choice: true, false
useAnnotation: true
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
  #SMAC (SMAC should be installed through nnictl)
  #codeDir: ~/nni/nni/examples/tuners/random_nas_tuner
  builtinTunerName: PPOTuner
  classArgs:
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 0
2 changes: 1 addition & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
@@ -167,7 +167,7 @@ export namespace ValidationSchemas {
            checkpointDir: joi.string().allow('')
        }),
        tuner: joi.object({
            builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'),
            builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
            codeDir: joi.string(),
            classFileName: joi.string(),
            className: joi.string(),
4 changes: 3 additions & 1 deletion src/sdk/pynni/nni/constants.py
@@ -30,7 +30,8 @@
    'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner',
    'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
    'MetisTuner': 'nni.metis_tuner.metis_tuner',
    'GPTuner': 'nni.gp_tuner.gp_tuner'
    'GPTuner': 'nni.gp_tuner.gp_tuner',
    'PPOTuner': 'nni.ppo_tuner.ppo_tuner'
}

ClassName = {
@@ -44,6 +45,7 @@
    'NetworkMorphism':'NetworkMorphismTuner',
    'MetisTuner':'MetisTuner',
    'GPTuner':'GPTuner',
    'PPOTuner': 'PPOTuner',

    'Medianstop': 'MedianstopAssessor',
    'Curvefitting': 'CurvefittingAssessor'
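
These two dicts act as a simple name registry. As a rough sketch of how a built-in tuner name could be resolved to a class (this loader function is illustrative, not NNI's actual loading code):

```python
import importlib

def create_builtin_tuner(name, **class_args):
    # look up the module path and class name registered in the dicts above,
    # then import the module and instantiate the tuner
    module = importlib.import_module(ModuleName[name])
    cls = getattr(module, ClassName[name])
    return cls(**class_args)

# e.g. create_builtin_tuner('PPOTuner', optimize_mode='maximize')
```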
7 changes: 6 additions & 1 deletion src/sdk/pynni/nni/msg_dispatcher.py
@@ -100,11 +100,16 @@ def handle_initialize(self, data):
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial_callback(self, id, params):
        """For the tuner to issue a trial config as soon as the config is generated"""
        send(CommandType.NewTrialJob, _pack_parameter(id, params))

    def handle_request_trial_jobs(self, data):
        # data: number of trial jobs
        ids = [_create_parameter_id() for _ in range(data)]
        _logger.debug("requesting for generating params of {}".format(ids))
        params_list = self.tuner.generate_multiple_parameters(ids)
        params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)

        for i, _ in enumerate(params_list):
            send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
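
The new `st_callback` hook lets a tuner submit each trial configuration the moment it is generated, instead of waiting for the whole batch to be returned. A minimal sketch of a tuner using the hook (the `StreamingTuner` class and its config are hypothetical, for illustration only):

```python
# Hypothetical sketch: dispatch each config through st_callback as soon as it
# is ready. Anything sent via the callback must not also be returned, or the
# dispatcher loop above would submit it twice.
from nni.tuner import Tuner

class StreamingTuner(Tuner):
    def generate_parameters(self, parameter_id):
        return {'lr': 0.01}  # placeholder config for the sketch

    def generate_multiple_parameters(self, parameter_id_list, st_callback=None):
        remaining = []
        for parameter_id in parameter_id_list:
            params = self.generate_parameters(parameter_id)
            if st_callback is not None:
                st_callback(parameter_id, params)  # send this trial immediately
            else:
                remaining.append(params)
        return remaining
```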
Empty file.
198 changes: 198 additions & 0 deletions src/sdk/pynni/nni/ppo_tuner/distri.py
@@ -0,0 +1,198 @@
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
functions for sampling from hidden state
"""

import tensorflow as tf

from .util import fc


class Pd:
    """
    A particular probability distribution
    """
    def flatparam(self):
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        return - self.neglogp(x)
    def get_shape(self):
        return self.flatparam().shape
    @property
    def shape(self):
        return self.get_shape()
    def __getitem__(self, idx):
        return self.__class__(self.flatparam()[idx])

class PdType:
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        raise NotImplementedError
    def pdfromflat(self, flat, mask, nsteps, size, is_act_model):
        return self.pdclass()(flat, mask, nsteps, size, is_act_model)
    def pdfromlatent(self, latent_vector, init_scale, init_bias):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)

class CategoricalPd(Pd):
    """
    Categorical probability distribution
    """
    def __init__(self, logits, mask_npinf, nsteps, size, is_act_model):
        self.logits = logits
        self.mask_npinf = mask_npinf
        self.nsteps = nsteps
        self.size = size
        self.is_act_model = is_act_model
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(self.logits, axis=-1)

    @property
    def mean(self):
        return tf.nn.softmax(self.logits)
    def neglogp(self, x):
        """
        Equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x).
        Note: we can't use sparse_softmax_cross_entropy_with_logits because
        the implementation does not allow second-order derivatives.
        """
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            # one-hot encoding
            x_shape_list = x.shape.as_list()
            logits_shape_list = self.logits.get_shape().as_list()[:-1]
            for xs, ls in zip(x_shape_list, logits_shape_list):
                if xs is not None and ls is not None:
                    assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)

            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            # already encoded
            assert x.shape.as_list() == self.logits.shape.as_list()

        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.logits,
            labels=x)

    def kl(self, other):
        """KL divergence, computed with max-subtraction for numerical stability"""
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)

    def entropy(self):
        """compute entropy"""
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)

    def sample(self):
        """sample from logits using the Gumbel-max trick"""
        if not self.is_act_model:
            re_res = tf.reshape(self.logits, [-1, self.nsteps, self.size])
            # add the -inf mask so that invalid actions can never be sampled
            masked_res = tf.math.add(re_res, self.mask_npinf)
            re_masked_res = tf.reshape(masked_res, [-1, self.size])

            u = tf.random_uniform(tf.shape(re_masked_res), dtype=self.logits.dtype)
            # -log(-log(u)) is Gumbel(0, 1) noise; argmax over (logits + noise) draws a sample
            return tf.argmax(re_masked_res - tf.log(-tf.log(u)), axis=-1)
        else:
            u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
            return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

class CategoricalPdType(PdType):
    """
    Factory for creating CategoricalPd instances
    """
    def __init__(self, ncat, nsteps, np_mask, is_act_model):
        self.ncat = ncat
        self.nsteps = nsteps
        self.np_mask = np_mask
        self.is_act_model = is_act_model
    def pdclass(self):
        return CategoricalPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        """add fc and create CategoricalPd"""
        pdparam, mask, mask_npinf = _matching_fc(latent_vector, 'pi', self.ncat, self.nsteps,
                                                 init_scale=init_scale, init_bias=init_bias,
                                                 np_mask=self.np_mask, is_act_model=self.is_act_model)
        return self.pdfromflat(pdparam, mask_npinf, self.nsteps, self.ncat, self.is_act_model), pdparam, mask, mask_npinf

    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        return []
    def sample_dtype(self):
        return tf.int32

def _matching_fc(tensor, name, size, nsteps, init_scale, init_bias, np_mask, is_act_model):
    """
    add fc op, and add mask op when not in action mode
    """
    if tensor.shape[-1] == size:
        # this branch is never expected to be taken here
        assert False
        return tensor
    else:
        mask = tf.get_variable("act_mask", dtype=tf.float32, initializer=np_mask[0], trainable=False)
        mask_npinf = tf.get_variable("act_mask_npinf", dtype=tf.float32, initializer=np_mask[1], trainable=False)
        res = fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
        if not is_act_model:
            re_res = tf.reshape(res, [-1, nsteps, size])
            # zero out positions that are invalid for the current step
            masked_res = tf.math.multiply(re_res, mask)
            re_masked_res = tf.reshape(masked_res, [-1, size])
            return re_masked_res, mask, mask_npinf
        else:
            return res, mask, mask_npinf
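
The `sample` method above relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the logits and taking the argmax is equivalent to drawing from the categorical distribution defined by softmax(logits). A quick standalone NumPy check (not part of this commit) makes the equivalence easy to verify:

```python
# Standalone sanity check: argmax(logits + Gumbel noise) should match
# sampling from Categorical(softmax(logits)).
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0])

# softmax reference probabilities
probs = np.exp(logits - logits.max())
probs /= probs.sum()

n = 200_000
u = rng.uniform(size=(n, logits.size))
gumbel = -np.log(-np.log(u))  # Gumbel(0, 1) noise, as in sample() above
samples = np.argmax(logits + gumbel, axis=1)

empirical = np.bincount(samples, minlength=logits.size) / n
print(np.round(probs, 3))      # reference distribution
print(np.round(empirical, 3))  # should closely match the reference
```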
