diff --git a/docs/HowToChooseTuner.md b/docs/HowToChooseTuner.md index 48e2ee15d0..e1dc531095 100644 --- a/docs/HowToChooseTuner.md +++ b/docs/HowToChooseTuner.md @@ -11,6 +11,8 @@ For now, NNI has supported the following tuner algorithms. Note that NNI install - [Grid Search](#Grid) - [Hyperband](#Hyperband) - [Network Morphism](#NetworkMorphism) (require pyTorch) + - [Metis Tuner](#MetisTuner) (require sklearn) + ## Supported tuner algorithms @@ -178,7 +180,7 @@ _Usage_: **Network Morphism** -[Network Morphism](7) provides functions to automatically search for architecture of deep learning models. Every child network inherits the knowledge from its parent network and morphs into diverse types of networks, including changes of depth, width and skip-connection. Next, it estimates the value of child network using the history architecture and metric pairs. Then it selects the most promising one to train. More detail can be referred to [here](../src/sdk/pynni/nni/networkmorphism_tuner/README.md). +[Network Morphism][7] provides functions to automatically search for architecture of deep learning models. Every child network inherits the knowledge from its parent network and morphs into diverse types of networks, including changes of depth, width and skip-connection. Next, it estimates the value of child network using the history architecture and metric pairs. Then it selects the most promising one to train. More detail can be referred to [here](../src/sdk/pynni/nni/networkmorphism_tuner/README.md). _Installation_: NetworkMorphism requires [pyTorch](https://pytorch.org/get-started/locally), so users should install it first. @@ -205,6 +207,43 @@ _Usage_: ``` + +**Metis Tuner** + +[Metis][10] offers the following benefits when it comes to tuning parameters: +While most tools only predicts the optimal configuration, Metis gives you two outputs: (a) current prediction of optimal configuration, and (b) suggestion for the next trial. No more guess work! + +While most tools assume training datasets do not have noisy data, Metis actually tells you if you need to re-sample a particular hyper-parameter. + +While most tools have problems of being exploitation-heavy, Metis' search strategy balances exploration, exploitation, and (optional) re-sampling. + +Metis belongs to the class of sequential model-based optimization (SMBO), and it is based on the Bayesian Optimization framework. To model the parameter-vs-performance space, Metis uses both Gaussian Process and GMM. Since each trial can impose a high time cost, Metis heavily trades inference computations with naive trial. At each iteration, Metis does two tasks: +* It finds the global optimal point in the Gaussian Process space. This point represents the optimal configuration. +* It identifies the next hyper-parameter candidate. This is achieved by inferring the potential information gain of exploration, exploitation, and re-sampling. + +Note that the only acceptable types of search space are `choice`, `quniform`, `uniform` and `randint`. + +More details can be found in our paper: https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/ + + +_Installation_: +Metis Tuner requires [sklearn](https://scikit-learn.org/), so users should install it first. User could use `pip3 install sklearn` to install it. + + +_Suggested scenario_: +Similar to TPE and SMAC, Metis is a black-box tuner. If your system takes a long time to finish each trial, Metis is more favorable than other approaches such as random search. Furthermore, Metis provides guidance on the subsequent trial. Here is an [example](../examples/trials/auto-gbdt/search_space_metis.json) about the use of Metis. User only need to send the final result like `accuracy` to tuner, by calling the nni SDK. + +_Usage_: +```yaml + # config.yaml + tuner: + builtinTunerName: MetisTuner + classArgs: + #choice: maximize, minimize + optimize_mode: maximize +``` + + # How to use Assessor that NNI supports? For now, NNI has supported the following assessor algorithms. @@ -273,3 +312,4 @@ _Usage_: [7]: https://arxiv.org/abs/1806.10282 [8]: https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf [9]: http://aad.informatik.uni-freiburg.de/papers/15-IJCAI-Extrapolation_of_Learning_Curves.pdf +[10]:https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/ diff --git a/examples/trials/auto-gbdt/config_metis.yml b/examples/trials/auto-gbdt/config_metis.yml new file mode 100644 index 0000000000..b52d53e69f --- /dev/null +++ b/examples/trials/auto-gbdt/config_metis.yml @@ -0,0 +1,21 @@ +authorName: default +experimentName: example_auto-gbdt-metis +trialConcurrency: 1 +maxExecDuration: 10h +maxTrialNum: 10 +#choice: local, remote, pai +trainingServicePlatform: local +searchSpacePath: search_space_metis.json +#choice: true, false +useAnnotation: false +tuner: + #choice: TPE, Random, Anneal, Evolution, BatchTuner + #SMAC (SMAC should be installed through nnictl) + builtinTunerName: MetisTuner + classArgs: + #choice: maximize, minimize + optimize_mode: minimize +trial: + command: python3 main.py + codeDir: . + gpuNum: 0 diff --git a/examples/trials/auto-gbdt/requirments.txt b/examples/trials/auto-gbdt/requirments.txt index 87509da343..182230bed8 100644 --- a/examples/trials/auto-gbdt/requirments.txt +++ b/examples/trials/auto-gbdt/requirments.txt @@ -1 +1,2 @@ -pip install lightgbm +lightgbm +pandas diff --git a/examples/trials/auto-gbdt/search_space_metis.json b/examples/trials/auto-gbdt/search_space_metis.json new file mode 100644 index 0000000000..6bfbc32afa --- /dev/null +++ b/examples/trials/auto-gbdt/search_space_metis.json @@ -0,0 +1,5 @@ +{ + "num_leaves":{"_type":"choice","_value":[31, 28, 24, 20]}, + "learning_rate":{"_type":"choice","_value":[0.01, 0.05, 0.1, 0.2]}, + "bagging_freq":{"_type":"choice","_value":[1, 2, 4, 8, 10]} +} diff --git a/src/nni_manager/rest_server/restValidationSchemas.ts b/src/nni_manager/rest_server/restValidationSchemas.ts index bfb1ff24d2..5b6bd1dab7 100644 --- a/src/nni_manager/rest_server/restValidationSchemas.ts +++ b/src/nni_manager/rest_server/restValidationSchemas.ts @@ -148,7 +148,7 @@ export namespace ValidationSchemas { checkpointDir: joi.string().allow('') }), tuner: joi.object({ - builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism'), + builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner'), codeDir: joi.string(), classFileName: joi.string(), className: joi.string(), diff --git a/src/sdk/pynni/nni/constants.py b/src/sdk/pynni/nni/constants.py index ba6d27144f..f6cce5adba 100644 --- a/src/sdk/pynni/nni/constants.py +++ b/src/sdk/pynni/nni/constants.py @@ -28,7 +28,8 @@ 'Medianstop': 'nni.medianstop_assessor.medianstop_assessor', 'GridSearch': 'nni.gridsearch_tuner.gridsearch_tuner', 'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner', - 'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor' + 'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor', + 'MetisTuner': 'nni.metis_tuner.metis_tuner' } ClassName = { @@ -40,6 +41,7 @@ 'BatchTuner': 'BatchTuner', 'GridSearch': 'GridSearchTuner', 'NetworkMorphism':'NetworkMorphismTuner', + 'MetisTuner':'MetisTuner', 'Medianstop': 'MedianstopAssessor', 'Curvefitting': 'CurvefittingAssessor' diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GMM/CreateModel.py b/src/sdk/pynni/nni/metis_tuner/Regression_GMM/CreateModel.py new file mode 100644 index 0000000000..3ed39e0cf8 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GMM/CreateModel.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import sys +from operator import itemgetter + +import sklearn.mixture as mm + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + +def create_model(samples_x, samples_y_aggregation, percentage_goodbatch=0.34): + ''' + Create the Gaussian Mixture Model + ''' + samples = [samples_x[i] + [samples_y_aggregation[i]] for i in range(0, len(samples_x))] + + # Sorts so that we can get the top samples + samples = sorted(samples, key=itemgetter(-1)) + samples_goodbatch_size = int(len(samples) * percentage_goodbatch) + samples_goodbatch = samples[0:samples_goodbatch_size] + samples_badbatch = samples[samples_goodbatch_size:] + + samples_x_goodbatch = [sample_goodbatch[0:-1] for sample_goodbatch in samples_goodbatch] + #samples_y_goodbatch = [sample_goodbatch[-1] for sample_goodbatch in samples_goodbatch] + samples_x_badbatch = [sample_badbatch[0:-1] for sample_badbatch in samples_badbatch] + + # === Trains GMM clustering models === # + #sys.stderr.write("[%s] Train GMM's GMM model\n" % (os.path.basename(__file__))) + bgmm_goodbatch = mm.BayesianGaussianMixture(n_components=max(1, samples_goodbatch_size - 1)) + bad_n_components = max(1, len(samples_x) - samples_goodbatch_size - 1) + bgmm_badbatch = mm.BayesianGaussianMixture(n_components=bad_n_components) + bgmm_goodbatch.fit(samples_x_goodbatch) + bgmm_badbatch.fit(samples_x_badbatch) + + model = {} + model['clusteringmodel_good'] = bgmm_goodbatch + model['clusteringmodel_bad'] = bgmm_badbatch + return model + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py b/src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py new file mode 100644 index 0000000000..4507e30886 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GMM/Selection.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import random +import sys + +import nni.metis_tuner.lib_acquisition_function as lib_acquisition_function +import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation +import nni.metis_tuner.lib_data as lib_data + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + +CONSTRAINT_LOWERBOUND = None +CONSTRAINT_UPPERBOUND = None +CONSTRAINT_PARAMS_IDX = [] + + +def _ratio_scores(parameters_value, clusteringmodel_gmm_good, clusteringmodel_gmm_bad): + ''' + The ratio is smaller the better + ''' + ratio = clusteringmodel_gmm_good.score([parameters_value]) / clusteringmodel_gmm_bad.score([parameters_value]) + sigma = 0 + return ratio, sigma + +def selection_r(x_bounds, + x_types, + clusteringmodel_gmm_good, + clusteringmodel_gmm_bad, + num_starting_points=100, + minimize_constraints_fun=None): + ''' + Call selection + ''' + minimize_starting_points = [lib_data.rand(x_bounds, x_types)\ + for i in range(0, num_starting_points)] + outputs = selection(x_bounds, x_types, + clusteringmodel_gmm_good, + clusteringmodel_gmm_bad, + minimize_starting_points, + minimize_constraints_fun) + return outputs + +def selection(x_bounds, + x_types, + clusteringmodel_gmm_good, + clusteringmodel_gmm_bad, + minimize_starting_points, + minimize_constraints_fun=None): + ''' + Select the lowest mu value + ''' + results = lib_acquisition_function.next_hyperparameter_lowest_mu(\ + _ratio_scores, [clusteringmodel_gmm_good, clusteringmodel_gmm_bad],\ + x_bounds, x_types, minimize_starting_points, \ + minimize_constraints_fun=minimize_constraints_fun) + + return results + +def _rand_with_constraints(x_bounds, x_types): + ''' + Random generate the variable with constraints + ''' + outputs = None + x_bounds_withconstraints = [x_bounds[i] for i in CONSTRAINT_PARAMS_IDX] + x_types_withconstraints = [x_types[i] for i in CONSTRAINT_PARAMS_IDX] + x_val_withconstraints = lib_constraint_summation.rand(x_bounds_withconstraints, + x_types_withconstraints, + CONSTRAINT_LOWERBOUND, + CONSTRAINT_UPPERBOUND) + if x_val_withconstraints is not None: + outputs = [None] * len(x_bounds) + for i, _ in enumerate(CONSTRAINT_PARAMS_IDX): + outputs[CONSTRAINT_PARAMS_IDX[i]] = x_val_withconstraints[i] + for i, _ in enumerate(outputs): + if outputs[i] is None: + outputs[i] = random.randint(x_bounds[i][0], x_bounds[i][1]) + return outputs + +def _minimize_constraints_fun_summation(x): + ''' + Minimize constraints fun summation + ''' + summation = sum([x[i] for i in CONSTRAINT_PARAMS_IDX]) + return CONSTRAINT_UPPERBOUND >= summation >= CONSTRAINT_LOWERBOUND diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GMM/__init__.py b/src/sdk/pynni/nni/metis_tuner/Regression_GMM/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/CreateModel.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/CreateModel.py new file mode 100644 index 0000000000..c1d16475c8 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GP/CreateModel.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import sys +import numpy + +import sklearn.gaussian_process as gp + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + +def create_model(samples_x, samples_y_aggregation, + n_restarts_optimizer=250, is_white_kernel=False): + ''' + Trains GP regression model + ''' + kernel = gp.kernels.ConstantKernel(constant_value=1, + constant_value_bounds=(1e-12, 1e12)) * \ + gp.kernels.Matern(nu=1.5) + if is_white_kernel is True: + kernel += gp.kernels.WhiteKernel(noise_level=1, noise_level_bounds=(1e-12, 1e12)) + regressor = gp.GaussianProcessRegressor(kernel=kernel, + n_restarts_optimizer=n_restarts_optimizer, + normalize_y=True, + alpha=0) + regressor.fit(numpy.array(samples_x), numpy.array(samples_y_aggregation)) + + model = {} + model['model'] = regressor + model['kernel_prior'] = str(kernel) + model['kernel_posterior'] = str(regressor.kernel_) + model['model_loglikelihood'] = regressor.log_marginal_likelihood(regressor.kernel_.theta) + + return model diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py new file mode 100644 index 0000000000..353c56f2b0 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GP/OutlierDetection.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +import argparse, json, os, sys +from multiprocessing.dummy import Pool as ThreadPool + +import nni.metis_tuner.Regression_GP.CreateModel as gp_create_model +import nni.metis_tuner.Regression_GP.Prediction as gp_prediction +import nni.metis_tuner.lib_data as lib_data + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + +def _outlierDetection_threaded(inputs): + ''' + Detect the outlier + ''' + [samples_idx, samples_x, samples_y_aggregation] = inputs + sys.stderr.write("[%s] DEBUG: Evaluating %dth of %d samples\n"\ + % (os.path.basename(__file__), samples_idx + 1, len(samples_x))) + outlier = None + + # Create a diagnostic regression model which removes the sample that we want to evaluate + diagnostic_regressor_gp = gp_create_model.createModel(\ + samples_x[0:samples_idx] + samples_x[samples_idx + 1:],\ + samples_y_aggregation[0:samples_idx] + samples_y_aggregation[samples_idx + 1:]) + mu, sigma = gp_prediction.predict(samples_x[samples_idx], diagnostic_regressor_gp['model']) + + # 2.33 is the z-score for 98% confidence level + if abs(samples_y_aggregation[samples_idx] - mu) > (2.33 * sigma): + outlier = {"samples_idx": samples_idx, + "expected_mu": mu, + "expected_sigma": sigma, + "difference": abs(samples_y_aggregation[samples_idx] - mu) - (2.33 * sigma)} + return outlier + +def outlierDetection_threaded(samples_x, samples_y_aggregation): + ''' + Use Multi-thread to detect the outlier + ''' + outliers = [] + + threads_inputs = [[samples_idx, samples_x, samples_y_aggregation]\ + for samples_idx in range(0, len(samples_x))] + threads_pool = ThreadPool(min(4, len(threads_inputs))) + threads_results = threads_pool.map(_outlierDetection_threaded, threads_inputs) + threads_pool.close() + threads_pool.join() + + for threads_result in threads_results: + if threads_result is not None: + outliers.append(threads_result) + else: + print("error here.") + + outliers = None if len(outliers) == 0 else outliers + return outliers + +def outlierDetection(samples_x, samples_y_aggregation): + ''' + ''' + outliers = [] + for samples_idx in range(0, len(samples_x)): + #sys.stderr.write("[%s] DEBUG: Evaluating %d of %d samples\n" + # \ % (os.path.basename(__file__), samples_idx + 1, len(samples_x))) + diagnostic_regressor_gp = gp_create_model.createModel(\ + samples_x[0:samples_idx] + samples_x[samples_idx + 1:],\ + samples_y_aggregation[0:samples_idx] + samples_y_aggregation[samples_idx + 1:]) + mu, sigma = gp_prediction.predict(samples_x[samples_idx], + diagnostic_regressor_gp['model']) + # 2.33 is the z-score for 98% confidence level + if abs(samples_y_aggregation[samples_idx] - mu) > (2.33 * sigma): + outliers.append({"samples_idx": samples_idx, + "expected_mu": mu, + "expected_sigma": sigma, + "difference": abs(samples_y_aggregation[samples_idx] - mu) - (2.33 * sigma)}) + + outliers = None if len(outliers) == 0 else outliers + return outliers + + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/Prediction.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/Prediction.py new file mode 100644 index 0000000000..82d3d0353f --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GP/Prediction.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import sys + +import numpy + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + + +def predict(parameters_value, regressor_gp): + ''' + Predict by Gaussian Process Model + ''' + parameters_value = numpy.array(parameters_value).reshape(-1, len(parameters_value)) + mu, sigma = regressor_gp.predict(parameters_value, return_std=True) + + return mu[0], sigma[0] + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/Selection.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/Selection.py new file mode 100644 index 0000000000..9c8e384a3d --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/Regression_GP/Selection.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os +import random +import sys + +import nni.metis_tuner.lib_acquisition_function as lib_acquisition_function +import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation +import nni.metis_tuner.lib_data as lib_data +import nni.metis_tuner.Regression_GP.Prediction as gp_prediction + +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +CONSTRAINT_LOWERBOUND = None +CONSTRAINT_UPPERBOUND = None +CONSTRAINT_PARAMS_IDX = [] + + +def selection_r(acquisition_function, + samples_y_aggregation, + x_bounds, + x_types, + regressor_gp, + num_starting_points=100, + minimize_constraints_fun=None): + ''' + Selecte R value + ''' + minimize_starting_points = [lib_data.rand(x_bounds, x_types) \ + for i in range(0, num_starting_points)] + outputs = selection(acquisition_function, samples_y_aggregation, + x_bounds, x_types, regressor_gp, + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) + + return outputs + +def selection(acquisition_function, + samples_y_aggregation, + x_bounds, x_types, + regressor_gp, + minimize_starting_points, + minimize_constraints_fun=None): + ''' + selection + ''' + outputs = None + + sys.stderr.write("[%s] Exercise \"%s\" acquisition function\n" \ + % (os.path.basename(__file__), acquisition_function)) + + if acquisition_function == "ei": + outputs = lib_acquisition_function.next_hyperparameter_expected_improvement(\ + gp_prediction.predict, [regressor_gp], x_bounds, x_types, \ + samples_y_aggregation, minimize_starting_points, \ + minimize_constraints_fun=minimize_constraints_fun) + elif acquisition_function == "lc": + outputs = lib_acquisition_function.next_hyperparameter_lowest_confidence(\ + gp_prediction.predict, [regressor_gp], x_bounds, x_types,\ + minimize_starting_points, minimize_constraints_fun=minimize_constraints_fun) + elif acquisition_function == "lm": + outputs = lib_acquisition_function.next_hyperparameter_lowest_mu(\ + gp_prediction.predict, [regressor_gp], x_bounds, x_types,\ + minimize_starting_points, minimize_constraints_fun=minimize_constraints_fun) + return outputs + +def _rand_with_constraints(x_bounds, x_types): + ''' + Random generate with constraints + ''' + outputs = None + + x_bounds_withconstraints = [x_bounds[i] for i in CONSTRAINT_PARAMS_IDX] + x_types_withconstraints = [x_types[i] for i in CONSTRAINT_PARAMS_IDX] + x_val_withconstraints = lib_constraint_summation.rand(x_bounds_withconstraints, + x_types_withconstraints, + CONSTRAINT_LOWERBOUND, + CONSTRAINT_UPPERBOUND) + if x_val_withconstraints is not None: + outputs = [None] * len(x_bounds) + + for i, _ in enumerate(CONSTRAINT_PARAMS_IDX): + outputs[CONSTRAINT_PARAMS_IDX[i]] = x_val_withconstraints[i] + + for i, _ in enumerate(outputs): + if outputs[i] is None: + outputs[i] = random.randint(x_bounds[i][0], x_bounds[i][1]) + return outputs + + +def _minimize_constraints_fun_summation(x): + ''' + Minimize the constraints fun summation + ''' + summation = sum([x[i] for i in CONSTRAINT_PARAMS_IDX]) + return CONSTRAINT_UPPERBOUND >= summation >= CONSTRAINT_LOWERBOUND diff --git a/src/sdk/pynni/nni/metis_tuner/Regression_GP/__init__.py b/src/sdk/pynni/nni/metis_tuner/Regression_GP/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py b/src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py new file mode 100644 index 0000000000..1caf8c814a --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/lib_acquisition_function.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import sys +import numpy + +from scipy.stats import norm +from scipy.optimize import minimize + +import nni.metis_tuner.lib_data as lib_data + + +def next_hyperparameter_expected_improvement(fun_prediction, + fun_prediction_args, + x_bounds, x_types, + samples_y_aggregation, + minimize_starting_points, + minimize_constraints_fun=None): + ''' + "Expected Improvement" acquisition function + ''' + best_x = None + best_acquisition_value = None + x_bounds_minmax = [[i[0], i[-1]] for i in x_bounds] + x_bounds_minmax = numpy.array(x_bounds_minmax) + + for starting_point in numpy.array(minimize_starting_points): + res = minimize(fun=_expected_improvement, + x0=starting_point.reshape(1, -1), + bounds=x_bounds_minmax, + method="L-BFGS-B", + args=(fun_prediction, + fun_prediction_args, + x_bounds, + x_types, + samples_y_aggregation, + minimize_constraints_fun)) + + if (best_acquisition_value is None) or \ + (res.fun < best_acquisition_value): + res.x = numpy.ndarray.tolist(res.x) + res.x = lib_data.match_val_type(res.x, x_bounds, x_types) + if (minimize_constraints_fun is None) or \ + (minimize_constraints_fun(res.x) is True): + best_acquisition_value = res.fun + best_x = res.x + + outputs = None + if best_x is not None: + mu, sigma = fun_prediction(best_x, *fun_prediction_args) + outputs = {'hyperparameter': best_x, 'expected_mu': mu, + 'expected_sigma': sigma, 'acquisition_func': "ei"} + + return outputs + +def _expected_improvement(x, fun_prediction, fun_prediction_args, + x_bounds, x_types, samples_y_aggregation, + minimize_constraints_fun): + # This is only for step-wise optimization + x = lib_data.match_val_type(x, x_bounds, x_types) + + expected_improvement = sys.maxsize + if (minimize_constraints_fun is None) or (minimize_constraints_fun(x) is True): + mu, sigma = fun_prediction(x, *fun_prediction_args) + + loss_optimum = min(samples_y_aggregation) + scaling_factor = -1 + + # In case sigma equals zero + with numpy.errstate(divide="ignore"): + Z = scaling_factor * (mu - loss_optimum) / sigma + expected_improvement = scaling_factor * (mu - loss_optimum) * \ + norm.cdf(Z) + sigma * norm.pdf(Z) + expected_improvement = 0.0 if sigma == 0.0 else expected_improvement + + # We want expected_improvement to be as large as possible + # (i.e., as small as possible for minimize(...)) + expected_improvement = -1 * expected_improvement + return expected_improvement + + +def next_hyperparameter_lowest_confidence(fun_prediction, + fun_prediction_args, + x_bounds, x_types, + minimize_starting_points, + minimize_constraints_fun=None): + ''' + "Lowest Confidence" acquisition function + ''' + best_x = None + best_acquisition_value = None + x_bounds_minmax = [[i[0], i[-1]] for i in x_bounds] + x_bounds_minmax = numpy.array(x_bounds_minmax) + + for starting_point in numpy.array(minimize_starting_points): + res = minimize(fun=_lowest_confidence, + x0=starting_point.reshape(1, -1), + bounds=x_bounds_minmax, + method="L-BFGS-B", + args=(fun_prediction, + fun_prediction_args, + x_bounds, + x_types, + minimize_constraints_fun)) + + if (best_acquisition_value) is None or (res.fun < best_acquisition_value): + res.x = numpy.ndarray.tolist(res.x) + res.x = lib_data.match_val_type(res.x, x_bounds, x_types) + if (minimize_constraints_fun is None) or (minimize_constraints_fun(res.x) is True): + best_acquisition_value = res.fun + best_x = res.x + + outputs = None + if best_x is not None: + mu, sigma = fun_prediction(best_x, *fun_prediction_args) + outputs = {'hyperparameter': best_x, 'expected_mu': mu, + 'expected_sigma': sigma, 'acquisition_func': "lc"} + return outputs + +def _lowest_confidence(x, fun_prediction, fun_prediction_args, + x_bounds, x_types, minimize_constraints_fun): + # This is only for step-wise optimization + x = lib_data.match_val_type(x, x_bounds, x_types) + + ci = sys.maxsize + if (minimize_constraints_fun is None) or (minimize_constraints_fun(x) is True): + mu, sigma = fun_prediction(x, *fun_prediction_args) + ci = (sigma * 1.96 * 2) / mu + # We want ci to be as large as possible + # (i.e., as small as possible for minimize(...), + # because this would mean lowest confidence + ci = -1 * ci + + return ci + + +def next_hyperparameter_lowest_mu(fun_prediction, + fun_prediction_args, + x_bounds, x_types, + minimize_starting_points, + minimize_constraints_fun=None): + ''' + "Lowest Mu" acquisition function + ''' + best_x = None + best_acquisition_value = None + x_bounds_minmax = [[i[0], i[-1]] for i in x_bounds] + x_bounds_minmax = numpy.array(x_bounds_minmax) + + for starting_point in numpy.array(minimize_starting_points): + res = minimize(fun=_lowest_mu, + x0=starting_point.reshape(1, -1), + bounds=x_bounds_minmax, + method="L-BFGS-B", + args=(fun_prediction, fun_prediction_args, \ + x_bounds, x_types, minimize_constraints_fun)) + + if (best_acquisition_value is None) or (res.fun < best_acquisition_value): + res.x = numpy.ndarray.tolist(res.x) + res.x = lib_data.match_val_type(res.x, x_bounds, x_types) + if (minimize_constraints_fun is None) or (minimize_constraints_fun(res.x) is True): + best_acquisition_value = res.fun + best_x = res.x + + outputs = None + if best_x is not None: + mu, sigma = fun_prediction(best_x, *fun_prediction_args) + outputs = {'hyperparameter': best_x, 'expected_mu': mu, + 'expected_sigma': sigma, 'acquisition_func': "lm"} + return outputs + + +def _lowest_mu(x, fun_prediction, fun_prediction_args, + x_bounds, x_types, minimize_constraints_fun): + ''' + Calculate the lowest mu + ''' + # This is only for step-wise optimization + x = lib_data.match_val_type(x, x_bounds, x_types) + + mu = sys.maxsize + if (minimize_constraints_fun is None) or (minimize_constraints_fun(x) is True): + mu, _ = fun_prediction(x, *fun_prediction_args) + return mu + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/lib_constraint_summation.py b/src/sdk/pynni/nni/metis_tuner/lib_constraint_summation.py new file mode 100644 index 0000000000..1e9daaee95 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/lib_constraint_summation.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import math +import random + +from operator import itemgetter + + +def check_feasibility(x_bounds, lowerbound, upperbound): + ''' + This can have false positives. + For examples, parameters can only be 0 or 5, and the summation constraint is between 6 and 7. + ''' + # x_bounds should be sorted, so even for "discrete_int" type, + # the smallest and the largest number should the first and the last element + x_bounds_lowerbound = sum([x_bound[0] for x_bound in x_bounds]) + x_bounds_upperbound = sum([x_bound[-1] for x_bound in x_bounds]) + + # return ((x_bounds_lowerbound <= lowerbound) and (x_bounds_upperbound >= lowerbound)) or \ + # ((x_bounds_lowerbound <= upperbound) and (x_bounds_upperbound >= upperbound)) + return (x_bounds_lowerbound <= lowerbound <= x_bounds_upperbound) or \ + (x_bounds_lowerbound <= upperbound <= x_bounds_upperbound) + +def rand(x_bounds, x_types, lowerbound, upperbound, max_retries=100): + ''' + Key idea is that we try to move towards upperbound, by randomly choose one + value for each parameter. However, for the last parameter, + we need to make sure that its value can help us get above lowerbound + ''' + outputs = None + + if check_feasibility(x_bounds, lowerbound, upperbound) is True: + # Order parameters by their range size. We want the smallest range first, + # because the corresponding parameter has less numbers to choose from + x_idx_sorted = [] + for i, _ in enumerate(x_bounds): + if x_types[i] == "discrete_int": + x_idx_sorted.append([i, len(x_bounds[i])]) + elif (x_types[i] == "range_int") or (x_types[i] == "range_continuous"): + x_idx_sorted.append([i, math.floor(x_bounds[i][1] - x_bounds[i][0])]) + x_idx_sorted = sorted(x_idx_sorted, key=itemgetter(1)) + + for _ in range(max_retries): + budget_allocated = 0 + outputs = [None] * len(x_bounds) + + for i, _ in enumerate(x_idx_sorted): + x_idx = x_idx_sorted[i][0] + # The amount of unallocated space that we have + budget_max = upperbound - budget_allocated + # NOT the Last x that we need to assign a random number + if i < (len(x_idx_sorted) - 1): + if x_bounds[x_idx][0] <= budget_max: + if x_types[x_idx] == "discrete_int": + # Note the valid integer + temp = [] + for j in x_bounds[x_idx]: + if j <= budget_max: + temp.append(j) + # Randomly pick a number from the integer array + if temp: + outputs[x_idx] = temp[random.randint(0, len(temp) - 1)] + + elif (x_types[x_idx] == "range_int") or \ + (x_types[x_idx] == "range_continuous"): + outputs[x_idx] = random.randint(x_bounds[x_idx][0], + min(x_bounds[x_idx][-1], budget_max)) + + else: + # The last x that we need to assign a random number + randint_lowerbound = lowerbound - budget_allocated + randint_lowerbound = 0 if randint_lowerbound < 0 else randint_lowerbound + + # This check: + # is our smallest possible value going to overflow the available budget space, + # and is our largest possible value going to underflow the lower bound + if (x_bounds[x_idx][0] <= budget_max) and \ + (x_bounds[x_idx][-1] >= randint_lowerbound): + if x_types[x_idx] == "discrete_int": + temp = [] + for j in x_bounds[x_idx]: + # if (j <= budget_max) and (j >= randint_lowerbound): + if randint_lowerbound <= j <= budget_max: + temp.append(j) + if temp: + outputs[x_idx] = temp[random.randint(0, len(temp) - 1)] + elif (x_types[x_idx] == "range_int") or \ + (x_types[x_idx] == "range_continuous"): + outputs[x_idx] = random.randint(randint_lowerbound, + min(x_bounds[x_idx][1], budget_max)) + if outputs[x_idx] is None: + break + else: + budget_allocated += outputs[x_idx] + if None not in outputs: + break + return outputs + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/lib_data.py b/src/sdk/pynni/nni/metis_tuner/lib_data.py new file mode 100644 index 0000000000..d24aeed678 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/lib_data.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import math +import random + + +def match_val_type(vals, vals_bounds, vals_types): + ''' + Update values in the array, to match their corresponding type + ''' + vals_new = [] + + for i, _ in enumerate(vals_types): + if vals_types[i] == "discrete_int": + # Find the closest integer in the array, vals_bounds + vals_new.append(min(vals_bounds[i], key=lambda x: abs(x - vals[i]))) + elif vals_types[i] == "range_int": + # Round down to the nearest integer + vals_new.append(math.floor(vals[i])) + elif vals_types[i] == "range_continuous": + # Don't do any processing for continous numbers + vals_new.append(vals[i]) + else: + return None + + return vals_new + + +def rand(x_bounds, x_types): + ''' + Random generate variable value within their bounds + ''' + outputs = [] + + for i, _ in enumerate(x_bounds): + if x_types[i] == "discrete_int": + temp = x_bounds[i][random.randint(0, len(x_bounds[i]) - 1)] + outputs.append(temp) + elif x_types[i] == "range_int": + temp = random.randint(x_bounds[i][0], x_bounds[i][1]) + outputs.append(temp) + elif x_types[i] == "range_continuous": + temp = random.uniform(x_bounds[i][0], x_bounds[i][1]) + outputs.append(temp) + else: + return None + + return outputs + \ No newline at end of file diff --git a/src/sdk/pynni/nni/metis_tuner/metis_tuner.py b/src/sdk/pynni/nni/metis_tuner/metis_tuner.py new file mode 100644 index 0000000000..a6a0608c25 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/metis_tuner.py @@ -0,0 +1,440 @@ +# Copyright (c) Microsoft Corporation +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, +# to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and +# to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING +# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import copy +import logging +import os +import random +import statistics +import sys + +from enum import Enum, unique +from multiprocessing.dummy import Pool as ThreadPool + +from nni.tuner import Tuner + +import nni.metis_tuner.lib_data as lib_data +import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation +import nni.metis_tuner.Regression_GP.CreateModel as gp_create_model +import nni.metis_tuner.Regression_GP.Selection as gp_selection +import nni.metis_tuner.Regression_GP.Prediction as gp_prediction +import nni.metis_tuner.Regression_GP.OutlierDetection as gp_outlier_detection +import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model +import nni.metis_tuner.Regression_GMM.Selection as gmm_selection + +logger = logging.getLogger("Metis_Tuner_AutoML") + +@unique +class OptimizeMode(Enum): + ''' + Optimize Mode class + ''' + Minimize = 'minimize' + Maximize = 'maximize' + + +NONE_TYPE = '' +CONSTRAINT_LOWERBOUND = None +CONSTRAINT_UPPERBOUND = None +CONSTRAINT_PARAMS_IDX = [] + + +class MetisTuner(Tuner): + ''' + Metis Tuner + ''' + + def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=True, + selection_num_starting_points=10, cold_start_num=10): + ''' + optimize_mode: is a string that including two mode "maximize" and "minimize" + + no_resampling: True or False. Should Metis consider re-sampling as part of the search strategy? + If you are confident that the training dataset is noise-free, then you do not need re-sampling. + + no_candidates: True or False. Should Metis suggest parameters for the next benchmark? + If you do not plan to do more benchmarks, Metis can skip this step. + + selection_num_starting_points: how many times Metis should try to find the global optimal in the search space? + The higher the number, the longer it takes to output the solution. + + cold_start_num: Metis need some trial result to get cold start. when the number of trial result is less than + cold_start_num, Metis will randomly sample hyper-parameter for trial. + ''' + self.samples_x = [] + self.samples_y = [] + self.samples_y_aggregation = [] + self.space = None + self.no_resampling = no_resampling + self.no_candidates = no_candidates + self.optimize_mode = optimize_mode + self.key_order = [] + self.cold_start_num = cold_start_num + self.selection_num_starting_points = selection_num_starting_points + self.minimize_constraints_fun = None + self.minimize_starting_points = None + + + def update_search_space(self, search_space): + ''' + Update the self.x_bounds and self.x_types by the search_space.json + ''' + self.x_bounds = [[] for i in range(len(search_space))] + self.x_types = [NONE_TYPE for i in range(len(search_space))] + + for key in search_space: + self.key_order.append(key) + + key_type = {} + if isinstance(search_space, dict): + for key in search_space: + key_type = search_space[key]['_type'] + key_range = search_space[key]['_value'] + try: + idx = self.key_order.index(key) + except Exception as ex: + logger.exception(ex) + raise RuntimeError("The format search space contains \ + some key that didn't define in key_order.") + + if key_type == 'quniform': + if key_range[2] == 1: + self.x_bounds[idx] = [key_range[0], key_range[1]] + self.x_types[idx] = 'range_int' + else: + bounds = [] + for value in range(key_range[0], key_range[1], key_range[2]): + bounds.append(value) + self.x_bounds[idx] = bounds + self.x_types[idx] = 'discrete_int' + elif key_type == 'randint': + self.x_bounds[idx] = [0, key_range[0]] + self.x_types[idx] = 'range_int' + elif key_type == 'uniform': + self.x_bounds[idx] = [key_range[0], key_range[1]] + self.x_types[idx] = 'range_continuous' + elif key_type == 'choice': + self.x_bounds[idx] = key_range + self.x_types[idx] = 'discrete_int' + else: + logger.info("Metis Tuner doesn't support this kind of variable.") + raise RuntimeError("Metis Tuner doesn't support this kind of variable.") + else: + logger.info("The format of search space is not a dict.") + raise RuntimeError("The format of search space is not a dict.") + + self.minimize_starting_points = _rand_init(self.x_bounds, self.x_types, \ + self.selection_num_starting_points) + + + def _pack_output(self, init_parameter): + ''' + Pack the output + ''' + output = {} + for i, param in enumerate(init_parameter): + output[self.key_order[i]] = param + return output + + + def generate_parameters(self, parameter_id): + ''' + This function is for generate parameters to trial. + If the number of trial result is lower than cold start number, + metis will first random generate some parameters. + Otherwise, metis will choose the parameters by the Gussian Process Model and the Gussian Mixture Model. + ''' + if self.samples_x or len(self.samples_x) < self.cold_start_num: + init_parameter = _rand_init(self.x_bounds, self.x_types, 1)[0] + results = self._pack_output(init_parameter) + else: + results = self._selection(self.samples_x, self.samples_y_aggregation, self.samples_y, + self.x_bounds, self.x_types, + threshold_samplessize_resampling=(None if self.no_resampling is True else 50), + no_candidates=self.no_candidates, + minimize_starting_points=self.minimize_starting_points, + minimize_constraints_fun=self.minimize_constraints_fun) + + logger.info("Generate paramageters:\n", str(results)) + return results + + + def receive_trial_result(self, parameter_id, parameters, value): + ''' + Tuner receive result from trial. + An value example as follow: + value: 99.5% + ''' + value = self.extract_scalar_reward(value) + if self.optimize_mode == OptimizeMode.Maximize: + value = -value + + logger.info("Received trial result.") + logger.info("value is :", str(value)) + logger.info("parameter is : ", str(parameters)) + + # parse parameter to sample_x + sample_x = [0 for i in range(len(self.key_order))] + for key in parameters: + idx = self.key_order.index(key) + sample_x[idx] = parameters[key] + + # parse value to sample_y + temp_y = [] + if sample_x in self.samples_x: + idx = self.samples_x.index(sample_x) + temp_y = self.samples_y[idx] + temp_y.append(value) + self.samples_y[idx] = temp_y + + # calculate y aggregation + median = get_median(temp_y) + self.samples_y_aggregation[idx] = median + else: + self.samples_x.append(sample_x) + self.samples_y.append([value]) + + # calculate y aggregation + self.samples_y_aggregation.append([value]) + + + def _selection(self, samples_x, samples_y_aggregation, samples_y, + x_bounds, x_types, max_resampling_per_x=3, + threshold_samplessize_exploitation=12, + threshold_samplessize_resampling=50, no_candidates=False, + minimize_starting_points=None, minimize_constraints_fun=None): + + next_candidate = None + candidates = [] + samples_size_all = sum([len(i) for i in samples_y]) + samples_size_unique = len(samples_y) + + # ===== STEP 1: Compute the current optimum ===== + #sys.stderr.write("[%s] Predicting the optimal configuration from the current training dataset...\n" % (os.path.basename(__file__))) + gp_model = gp_create_model.create_model(samples_x, samples_y_aggregation) + lm_current = gp_selection.selection("lm", samples_y_aggregation, x_bounds, + x_types, gp_model['model'], + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) + if not lm_current: + return None + + if no_candidates is False: + candidates.append({'hyperparameter': lm_current['hyperparameter'], + 'expected_mu': lm_current['expected_mu'], + 'expected_sigma': lm_current['expected_sigma'], + 'reason': "exploitation_gp"}) + + # ===== STEP 2: Get recommended configurations for exploration ===== + #sys.stderr.write("[%s] Getting candidates for exploration...\n" + #% \(os.path.basename(__file__))) + results_exploration = gp_selection.selection("lc", samples_y_aggregation, + x_bounds, x_types, gp_model['model'], + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) + + if results_exploration is not None: + if _num_past_samples(results_exploration['hyperparameter'], samples_x, samples_y) == 0: + candidates.append({'hyperparameter': results_exploration['hyperparameter'], + 'expected_mu': results_exploration['expected_mu'], + 'expected_sigma': results_exploration['expected_sigma'], + 'reason': "exploration"}) + logger.info("DEBUG: 1 exploration candidate selected\n") + #sys.stderr.write("[%s] DEBUG: 1 exploration candidate selected\n" % (os.path.basename(__file__))) + else: + logger.info("DEBUG: No suitable exploration candidates were") + # sys.stderr.write("[%s] DEBUG: No suitable exploration candidates were \ + # found\n" % (os.path.basename(__file__))) + + # ===== STEP 3: Get recommended configurations for exploitation ===== + if samples_size_all >= threshold_samplessize_exploitation: + #sys.stderr.write("[%s] Getting candidates for exploitation...\n" % (os.path.basename(__file__))) + print("Getting candidates for exploitation...\n") + try: + gmm = gmm_create_model.create_model(samples_x, samples_y_aggregation) + results_exploitation = gmm_selection.selection(x_bounds, + x_types, + gmm['clusteringmodel_good'], + gmm['clusteringmodel_bad'], + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) + + if results_exploitation is not None: + if _num_past_samples(results_exploitation['hyperparameter'], samples_x, samples_y) == 0: + candidates.append({'hyperparameter': results_exploitation['hyperparameter'],\ + 'expected_mu': results_exploitation['expected_mu'],\ + 'expected_sigma': results_exploitation['expected_sigma'],\ + 'reason': "exploitation_gmm"}) + logger.info("DEBUG: 1 exploitation_gmm candidate selected\n") + else: + logger.info("DEBUG: No suitable exploitation_gmm candidates were found\n") + + except ValueError as exception: + # The exception: ValueError: Fitting the mixture model failed + # because some components have ill-defined empirical covariance + # (for instance caused by singleton or collapsed samples). + # Try to decrease the number of components, or increase reg_covar. + logger.info("DEBUG: No suitable exploitation_gmm candidates were found due to exception.") + logger.info(exception) + + # ===== STEP 4: Get a list of outliers ===== + if (threshold_samplessize_resampling is not None) and \ + (samples_size_unique >= threshold_samplessize_resampling): + logger.info("Getting candidates for re-sampling...\n") + results_outliers = gp_outlier_detection.outlierDetection_threaded(samples_x, samples_y_aggregation) + + if results_outliers is not None: + temp = len(candidates) + + for results_outlier in results_outliers: + if _num_past_samples(samples_x[results_outlier['samples_idx']], samples_x, samples_y) < max_resampling_per_x: + candidates.append({'hyperparameter': samples_x[results_outlier['samples_idx']],\ + 'expected_mu': results_outlier['expected_mu'],\ + 'expected_sigma': results_outlier['expected_sigma'],\ + 'reason': "resampling"}) + logger.info("DEBUG: %d re-sampling candidates selected\n") + else: + logger.info("DEBUG: No suitable resampling candidates were found\n") + + if candidates: + # ===== STEP 5: Compute the information gain of each candidate towards the optimum ===== + logger.info("Evaluating information gain of %d candidates...\n") + next_improvement = 0 + + threads_inputs = [[candidate, samples_x, samples_y, x_bounds, x_types, minimize_constraints_fun, minimize_starting_points] for candidate in candidates] + threads_pool = ThreadPool(4) + # Evaluate what would happen if we actually sample each candidate + threads_results = threads_pool.map(_calculate_lowest_mu_threaded, threads_inputs) + threads_pool.close() + threads_pool.join() + + for threads_result in threads_results: + if threads_result['expected_lowest_mu'] < lm_current['expected_mu']: + # Information gain + temp_improvement = threads_result['expected_lowest_mu'] - lm_current['expected_mu'] + + if next_improvement > temp_improvement: + logger.infor("DEBUG: \"next_candidate\" changed: \ + lowest mu might reduce from %f (%s) to %f (%s), %s\n" %\ + lm_current['expected_mu'], str(lm_current['hyperparameter']),\ + threads_result['expected_lowest_mu'],\ + str(threads_result['candidate']['hyperparameter']),\ + threads_result['candidate']['reason']) + + next_improvement = temp_improvement + next_candidate = threads_result['candidate'] + else: + # ===== STEP 6: If we have no candidates, randomly pick one ===== + logger.info("DEBUG: No candidates from exploration, exploitation,\ + and resampling. We will random a candidate for next_candidate\n") + + next_candidate = _rand_with_constraints(x_bounds, x_types) \ + if minimize_starting_points is None else minimize_starting_points[0] + next_candidate = lib_data.match_val_type(next_candidate, x_bounds, x_types) + expected_mu, expected_sigma = gp_prediction.predict(next_candidate, gp_model['model']) + next_candidate = {'hyperparameter': next_candidate, 'reason': "random", + 'expected_mu': expected_mu, 'expected_sigma': expected_sigma} + + outputs = self._pack_output(lm_current['hyperparameter']) + return outputs + + +def _rand_with_constraints(x_bounds, x_types): + outputs = None + x_bounds_withconstraints = [x_bounds[i] for i in CONSTRAINT_PARAMS_IDX] + x_types_withconstraints = [x_types[i] for i in CONSTRAINT_PARAMS_IDX] + + x_val_withconstraints = lib_constraint_summation.rand(x_bounds_withconstraints,\ + x_types_withconstraints, CONSTRAINT_LOWERBOUND, CONSTRAINT_UPPERBOUND) + if not x_val_withconstraints: + outputs = [None] * len(x_bounds) + + for i, _ in enumerate(CONSTRAINT_PARAMS_IDX): + outputs[CONSTRAINT_PARAMS_IDX[i]] = x_val_withconstraints[i] + + for i, output in enumerate(outputs): + if not output: + outputs[i] = random.randint(x_bounds[i][0], x_bounds[i][1]) + return outputs + + +def _calculate_lowest_mu_threaded(inputs): + [candidate, samples_x, samples_y, x_bounds, x_types, minimize_constraints_fun, minimize_starting_points] = inputs + + sys.stderr.write("[%s] Evaluating information gain of %s (%s)...\n" % \ + (os.path.basename(__file__), candidate['hyperparameter'], candidate['reason'])) + outputs = {"candidate": candidate, "expected_lowest_mu": None} + + for expected_mu in [candidate['expected_mu'] + 1.96 * candidate['expected_sigma'], + candidate['expected_mu'] - 1.96 * candidate['expected_sigma']]: + temp_samples_x = copy.deepcopy(samples_x) + temp_samples_y = copy.deepcopy(samples_y) + + try: + idx = temp_samples_x.index(candidate['hyperparameter']) + # This handles the case of re-sampling a potential outlier + temp_samples_y[idx].append(expected_mu) + except ValueError: + temp_samples_x.append(candidate['hyperparameter']) + temp_samples_y.append([expected_mu]) + + # Aggregates multiple observation of the sample sampling points + temp_y_aggregation = [statistics.median(temp_sample_y) for temp_sample_y in temp_samples_y] + temp_gp = gp_create_model.create_model(temp_samples_x, temp_y_aggregation) + temp_results = gp_selection.selection("lm", temp_y_aggregation, + x_bounds, x_types, temp_gp['model'], + minimize_starting_points, + minimize_constraints_fun=minimize_constraints_fun) + + if outputs["expected_lowest_mu"] is None or outputs["expected_lowest_mu"] > temp_results['expected_mu']: + outputs["expected_lowest_mu"] = temp_results['expected_mu'] + + return outputs + + +def _num_past_samples(x, samples_x, samples_y): + try: + idx = samples_x.index(x) + return len(samples_y[idx]) + except ValueError: + logger.info("x not in sample_x") + return 0 + + +def _rand_init(x_bounds, x_types, selection_num_starting_points): + ''' + Random sample some init seed within bounds. + ''' + return [lib_data.rand(x_bounds, x_types) for i \ + in range(0, selection_num_starting_points)] + + +def get_median(temp_list): + ''' + Return median + ''' + num = len(temp_list) + temp_list.sort() + print(temp_list) + if num % 2 == 0: + median = (temp_list[int(num/2)] + temp_list[int(num/2) - 1]) / 2 + else: + median = temp_list[int(num/2)] + return median diff --git a/src/sdk/pynni/nni/metis_tuner/requirments.txt b/src/sdk/pynni/nni/metis_tuner/requirments.txt new file mode 100644 index 0000000000..044bdd7586 --- /dev/null +++ b/src/sdk/pynni/nni/metis_tuner/requirments.txt @@ -0,0 +1 @@ +sklearn \ No newline at end of file diff --git a/src/sdk/pynni/requirements.txt b/src/sdk/pynni/requirements.txt index 89de05d1c4..3adfb06cf5 100644 --- a/src/sdk/pynni/requirements.txt +++ b/src/sdk/pynni/requirements.txt @@ -4,4 +4,7 @@ json_tricks # hyperopt tuner numpy scipy -hyperopt \ No newline at end of file +hyperopt + +# metis tuner +sklearn diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index 0cc3824865..14c4bd3635 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -68,6 +68,16 @@ Optional('n_output_node'): int, }, Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999), +},{ + 'builtinTunerName': 'MetisTuner', + 'classArgs': { + Optional('optimize_mode'): Or('maximize', 'minimize'), + Optional('no_resampling'): bool, + Optional('no_candidates'): bool, + Optional('selection_num_starting_points'): int, + Optional('cold_start_num'): int, + }, + Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999), },{ 'codeDir': os.path.exists, 'classFileName': str,