diff --git a/README_zh_CN.md b/README_zh_CN.md
index 05b1273666..02cd6d3651 100644
--- a/README_zh_CN.md
+++ b/README_zh_CN.md
@@ -10,7 +10,7 @@
NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包。 它通过多种调优的算法来搜索最好的神经网络结构和(或)超参,并支持单机、本地多机、云等不同的运行环境。
-### **NNI [v1.0](https://github.com/Microsoft/nni/releases) 已发布! [ ](#nni-released-reminder)**
+### **NNI [v1.0](https://github.com/Microsoft/nni/blob/master/docs/zh_CN/Release_v1.0.md) 已发布! [ ](#nni-released-reminder)**
@@ -22,7 +22,7 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包
- 框架和库
+ 支持的框架和库
@@ -62,8 +62,8 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包
示例
- 网络结构搜索 Tuner
+ NAS Tuner
Network Morphism
ENAS
@@ -332,19 +332,30 @@ You can use these commands to get more information about the experiment
* [自定义 Tuner](docs/zh_CN/Tuner/CustomizeTuner.md)
* [实现定制的训练平台](docs/zh_CN/TrainingService/HowToImplementTrainingService.md)
-## **外部代码库**
-
-下面是一些贡献者为 NNI 提供的使用示例 谢谢可爱的贡献者! 欢迎越来越多的人加入我们!
-
-* 在 NNI 中运行 [ENAS](examples/tuners/enas_nni/README_zh_CN.md)
-* 在 NNI 中运行 [神经网络架构结构搜索](examples/trials/nas_cifar10/README_zh_CN.md)
-* [NNI 中的自动特征工程](examples/trials/auto-feature-engineering/README_zh_CN.md)
+## **其它代码库和参考**
+
+经作者许可的一些 NNI 用法示例和相关文档。
+
+* ### **外部代码库**
+
+ * 在 NNI 中运行 [ENAS](examples/tuners/enas_nni/README_zh_CN.md)
+ * 在 NNI 中运行 [神经网络架构结构搜索](examples/trials/nas_cifar10/README_zh_CN.md)
+ * [NNI 中的自动特征工程](examples/trials/auto-feature-engineering/README_zh_CN.md)
+ * 使用 NNI 的 [矩阵分解超参调优](https://github.com/microsoft/recommenders/blob/master/notebooks/04_model_select_and_optimize/nni_surprise_svd.ipynb)
+* ### **相关文章**
+
+ * [超参数优化的对比](docs/zh_CN/CommunitySharings/HpoComparision.md)
+ * [神经网络结构搜索的对比](docs/zh_CN/CommunitySharings/NasComparision.md)
+ * [并行化顺序算法:TPE](docs/zh_CN/CommunitySharings/ParallelizingTpeSearch.md)
+ * [使用 NNI 为 SVD 自动调参](docs/zh_CN/CommunitySharings/RecommendersSvd.md)
+ * [使用 NNI 为 SPTAG 自动调参](docs/zh_CN/CommunitySharings/SptagAutoTune.md)
+ * **博客** - [AutoML 工具(Advisor,NNI 与 Google Vizier)的对比](http://gaocegege.com/Blog/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/katib-new#%E6%80%BB%E7%BB%93%E4%B8%8E%E5%88%86%E6%9E%90) 作者:[@gaocegege](https://github.com/gaocegege) - kubeflow/katib 的设计与实现的总结与分析章节
## **反馈**
-* 在 [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 中参与讨论
-* 在 [Stack Overflow](https://stackoverflow.com/questions/tagged/nni?sort=Newest&edited=true) 上使用 NNI 标签提问
+* 在 [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 中参与讨论。
* [在 GitHub 上提交问题](https://github.com/microsoft/nni/issues/new/choose)。
+* 在 [Stack Overflow](https://stackoverflow.com/questions/tagged/nni?sort=Newest&edited=true) 上使用 nni 标签提问。
## **许可协议**
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index f142de7bd3..1563e4a0ee 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -15,7 +15,7 @@ jobs:
displayName: 'Install nni toolkit via source code'
- script: |
python3 -m pip install flake8 --user
- IGNORE=./tools/nni_annotation/testcase/*:F821,./examples/trials/mnist-nas/*/mnist*.py:F821
+ IGNORE=./tools/nni_annotation/testcase/*:F821,./examples/trials/mnist-nas/*/mnist*.py:F821,./examples/trials/nas_cifar10/src/cifar10/general_child.py:F821
python3 -m flake8 . --count --per-file-ignores=$IGNORE --select=E9,F63,F72,F82 --show-source --statistics
displayName: 'Run flake8 tests to find Python syntax errors and undefined names'
- script: |
diff --git a/deployment/docker/Dockerfile b/deployment/docker/Dockerfile
index e4c9d50f28..471a5b9964 100644
--- a/deployment/docker/Dockerfile
+++ b/deployment/docker/Dockerfile
@@ -68,8 +68,8 @@ RUN python3 -m pip --no-cache-dir install Keras==2.1.6
#
# PyTorch
#
-RUN python3 -m pip --no-cache-dir install torch==0.4.1
-RUN python3 -m pip install torchvision==0.2.1
+RUN python3 -m pip --no-cache-dir install torch==1.2.0
+RUN python3 -m pip install torchvision==0.4.0
#
# sklearn 0.20.0
diff --git a/docs/en_US/AdvancedFeature/MultiPhase.md b/docs/en_US/AdvancedFeature/MultiPhase.md
index 831a6c7544..9dbc22edc1 100644
--- a/docs/en_US/AdvancedFeature/MultiPhase.md
+++ b/docs/en_US/AdvancedFeature/MultiPhase.md
@@ -8,8 +8,6 @@ Typically each trial job gets a single configuration (e.g., hyperparameters) fro
The above cases can be supported by the same feature, i.e., multi-phase execution. To support those cases, basically a trial job should be able to request multiple configurations from tuner. Tuner is aware of whether two configuration requests are from the same trial job or different ones. Also in multi-phase a trial job can report multiple final results.
-Note that, `nni.get_next_parameter()` and `nni.report_final_result()` should be called sequentially: __call the former one, then call the later one; and repeat this pattern__. If `nni.get_next_parameter()` is called multiple times consecutively, and then `nni.report_final_result()` is called once, the result is associated to the last configuration, which is retrieved from the last get_next_parameter call. So there is no result associated to previous get_next_parameter calls, and it may cause some multi-phase algorithm broken.
-
## Create multi-phase experiment
### Write trial code which leverages multi-phase:
@@ -23,6 +21,9 @@ It is pretty simple to use multi-phase in trial code, an example is shown below:
for i in range(5):
    # get parameter from tuner
    tuner_param = nni.get_next_parameter()
+    # nni.get_next_parameter() returns None when the tuner cannot generate any more hyperparameters.
+    if tuner_param is None:
+        break
    # consume the params
    # ...
@@ -32,6 +33,10 @@ It is pretty simple to use multi-phase in trial code, an example is shown below:
# ...
```
+In multi-phase experiments, each call to `nni.get_next_parameter()` returns a new hyperparameter configuration generated by the tuner. The trial code then consumes this configuration and reports its final result. `nni.get_next_parameter()` and `nni.report_final_result()` should be called sequentially: __call the former, then call the latter, and repeat this pattern__. If `nni.get_next_parameter()` is called multiple times consecutively and `nni.report_final_result()` is then called once, the result is associated with the last configuration, i.e., the one retrieved by the last `nni.get_next_parameter()` call. No result is associated with the earlier calls, which may break some multi-phase algorithms.
+
+Note that `nni.get_next_parameter()` returns `None` when the tuner cannot generate any more hyperparameters.
+
__2. Experiment configuration__
To enable multi-phase, you should also add `multiPhase: true` in your experiment YAML configure file. If this line is not added, `nni.get_next_parameter()` would always return the same configuration.
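The call pattern described in this file (one `nni.report_final_result()` for each `nni.get_next_parameter()`, stopping on `None`) can be sketched without an NNI installation. `StubTuner` and `run_trial` below are hypothetical stand-ins written only for illustration; they are not part of the NNI SDK, and the score formula is a placeholder for real training.

```python
class StubTuner:
    """Hypothetical stand-in for the tuner side; NOT the real nni module."""

    def __init__(self, configs):
        self._pending = list(configs)
        self._current = None
        self.results = {}  # configuration id -> reported final result

    def get_next_parameter(self):
        # Mirrors the documented behaviour: None once nothing more can be generated.
        if not self._pending:
            self._current = None
            return None
        self._current = self._pending.pop(0)
        return self._current

    def report_final_result(self, value):
        # The result is attached to the most recently fetched configuration.
        self.results[self._current["id"]] = value


def run_trial(tuner, max_phases=5):
    """Fetch one configuration, report one result, and repeat."""
    for _ in range(max_phases):
        params = tuner.get_next_parameter()
        if params is None:
            break
        score = 1.0 - params["lr"]  # placeholder for real training
        tuner.report_final_result(score)
    return tuner.results


tuner = StubTuner([{"id": 0, "lr": 0.1}, {"id": 1, "lr": 0.01}])
results = run_trial(tuner)

# Misuse: two consecutive fetches followed by a single report.
misused = StubTuner([{"id": 0, "lr": 0.1}, {"id": 1, "lr": 0.01}])
misused.get_next_parameter()
misused.get_next_parameter()
misused.report_final_result(0.5)
```

In the misuse case only configuration `1` receives a result, which is the situation the paragraph above warns about. In real trial code the two stub methods would be replaced by `nni.get_next_parameter()` and `nni.report_final_result()`.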
diff --git a/docs/en_US/Tuner/BuiltinTuner.md b/docs/en_US/Tuner/BuiltinTuner.md
index d4f4fdfb09..c7d27fc5d1 100644
--- a/docs/en_US/Tuner/BuiltinTuner.md
+++ b/docs/en_US/Tuner/BuiltinTuner.md
@@ -20,6 +20,7 @@ Currently we support the following algorithms:
|[__Metis Tuner__](#MetisTuner)|Metis offers the following benefits when it comes to tuning parameters: While most tools only predict the optimal configuration, Metis gives you two outputs: (a) current prediction of optimal configuration, and (b) suggestion for the next trial. No more guesswork. While most tools assume training datasets do not have noisy data, Metis actually tells you if you need to re-sample a particular hyper-parameter. [Reference Paper](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/)|
|[__BOHB__](#BOHB)|BOHB is a follow-up work of Hyperband. It targets the weakness of Hyperband that new configurations are generated randomly without leveraging finished trials. For the name BOHB, HB means Hyperband, BO means Bayesian Optimization. BOHB leverages finished trials by building multiple TPE models, a proportion of new configurations are generated through these models. [Reference Paper](https://arxiv.org/abs/1807.01774)|
|[__GP Tuner__](#GPTuner)|Gaussian Process Tuner is a sequential model-based optimization (SMBO) approach with Gaussian Process as the surrogate. [Reference Paper](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf), [Github Repo](https://github.com/fmfn/BayesianOptimization)|
+|[__PPO Tuner__](#PPOTuner)|PPO Tuner is a Reinforcement Learning tuner based on the PPO algorithm. [Reference Paper](https://arxiv.org/abs/1707.06347)|
## Usage of Built-in Tuners
@@ -38,7 +39,7 @@ Note: Please follow the format when you write your `config.yml` file. Some built
TPE, as a black-box optimization, can be used in various scenarios and shows good performance in general. Especially when you have limited computation resource and can only try a small number of trials. From a large amount of experiments, we could found that TPE is far better than Random Search. [Detailed Description](./HyperoptTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -66,7 +67,7 @@ tuner:
Random search is suggested when each trial does not take too long (e.g., each trial can be completed very soon, or early stopped by assessor quickly), and you have enough computation resource. Or you want to uniformly explore the search space. Random Search could be considered as baseline of search algorithm. [Detailed Description](./HyperoptTuner.md)
-**Requirement of classArg:**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -91,7 +92,7 @@ tuner:
Anneal is suggested when each trial does not take too long, and you have enough computation resource(almost same with Random Search). Or the variables in search space could be sample from some prior distribution. [Detailed Description](./HyperoptTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -117,7 +118,7 @@ tuner:
Its requirement of computation resource is relatively high. Specifically, it requires large initial population to avoid falling into local optimum. If your trial is short or leverages assessor, this tuner is a good choice. And, it is more suggested when your trial code supports weight transfer, that is, the trial could inherit the converged weights from its parent(s). This can greatly speed up the training progress. [Detailed Description](./EvolutionTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -156,7 +157,7 @@ nnictl package install --name=SMAC
Similar to TPE, SMAC is also a black-box tuner which can be tried in various scenarios, and is suggested when computation resource is limited. It is optimized for discrete hyperparameters, thus, suggested when most of your hyperparameters are discrete. [Detailed Description](./SmacTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -243,7 +244,7 @@ tuner:
It is suggested when you have limited computation resource but have relatively large search space. It performs well in the scenario that intermediate result (e.g., accuracy) can reflect good or bad of final result (e.g., accuracy) to some extent. [Detailed Description](./HyperbandAdvisor.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
* **R** (*int, optional, default = 60*) - the maximum budget given to a trial (could be the number of mini-batches or epochs) can be allocated to a trial. Each trial should use TRIAL_BUDGET to control how long it runs.
@@ -277,7 +278,7 @@ NetworkMorphism requires [PyTorch](https://pytorch.org/get-started/locally) and
It is suggested that you want to apply deep learning methods to your task (your own dataset) but you have no idea of how to choose or design a network. You modify the [example](https://github.com/Microsoft/nni/tree/master/examples/trials/network_morphism/cifar10/cifar10_keras.py) to fit your own dataset and your own data augmentation method. Also you can change the batch size, learning rate or optimizer. It is feasible for different tasks to find a good network architecture. Now this tuner only supports the computer vision domain. [Detailed Description](./NetworkmorphismTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
* **task** (*('cv'), optional, default = 'cv'*) - The domain of experiment, for now, this tuner only supports the computer vision(cv) domain.
@@ -307,13 +308,12 @@ tuner:
> Built-in Tuner Name: **MetisTuner**
-Note that the only acceptable types of search space are `choice`, `quniform`, `uniform` and `randint`.
-
+Note that the only acceptable types of search space are `quniform`, `uniform`, `randint`, and numerical `choice`. Only numerical values are supported because the values are used to evaluate the 'distance' between different points.
**Suggested scenario**
Similar to TPE and SMAC, Metis is a black-box tuner. If your system takes a long time to finish each trial, Metis is more favorable than other approaches such as random search. Furthermore, Metis provides guidance on the subsequent trial. Here is an [example](https://github.com/Microsoft/nni/tree/master/examples/trials/auto-gbdt/search_space_metis.json) about the use of Metis. User only need to send the final result like `accuracy` to tuner, by calling the NNI SDK. [Detailed Description](./MetisTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*'maximize' or 'minimize', optional, default = 'maximize'*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
@@ -347,7 +347,7 @@ nnictl package install --name=BOHB
Similar to Hyperband, it is suggested when you have limited computation resource but have relatively large search space. It performs well in the scenario that intermediate result (e.g., accuracy) can reflect good or bad of final result (e.g., accuracy) to some extent. In this case, it may converges to a better configuration due to Bayesian optimization usage. [Detailed Description](./BohbAdvisor.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', tuners will target to maximize metrics. If 'minimize', tuner will target to minimize metrics.
* **min_budget** (*int, optional, default = 1*) - The smallest budget assign to a trial job, (budget could be the number of mini-batches or epochs). Needs to be positive.
@@ -380,13 +380,13 @@ advisor:
> Built-in Tuner Name: **GPTuner**
-Note that the only acceptable types of search space are `choice`, `randint`, `uniform`, `quniform`, `loguniform`, `qloguniform`.
+Note that the only acceptable types of search space are `randint`, `uniform`, `quniform`, `loguniform`, `qloguniform`, and numerical `choice`. Only numerical values are supported since the values will be used to evaluate the 'distance' between different points.
**Suggested scenario**
As a strategy in Sequential Model-based Global Optimization(SMBO) algorithm, GP Tuner uses a proxy optimization problem (finding the maximum of the acquisition function) that, albeit still a hard problem, is cheaper (in the computational sense) and common tools can be employed. Therefore GP Tuner is most adequate for situations where the function to be optimized is a very expensive endeavor. GP can be used when the computation resource is limited. While GP Tuner has a computational cost that grows at *O(N^3)* due to the requirement of inverting the Gram matrix, so it's not suitable when lots of trials are needed. [Detailed Description](./GPTuner.md)
-**Requirement of classArg**
+**Requirement of classArgs**
* **optimize_mode** (*'maximize' or 'minimize', optional, default = 'maximize'*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
* **utility** (*'ei', 'ucb' or 'poi', optional, default = 'ei'*) - The kind of utility function(acquisition function). 'ei', 'ucb' and 'poi' corresponds to 'Expected Improvement', 'Upper Confidence Bound' and 'Probability of Improvement' respectively.
@@ -415,3 +415,39 @@ tuner:
selection_num_warm_up: 100000
selection_num_starting_points: 250
```
+
+
+
+![](https://placehold.it/15/1589F0/000000?text=+) `PPO Tuner`
+
+> Built-in Tuner Name: **PPOTuner**
+
+Note that the only acceptable type of search space is `mutable_layer`. `optional_input_size` can only be 0, 1, or [0, 1].
+
+**Suggested scenario**
+
+PPOTuner is a Reinforcement Learning tuner based on the PPO algorithm. It is recommended when you use the NNI NAS interface in your trial code to do neural architecture search. It has relatively high data efficiency, but it is suggested only when you have a large amount of computation resources. You could try it on a very simple task, such as the [mnist-nas](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas) example. [Detailed Description](./PPOTuner.md)
+
+**Requirement of classArgs**
+
+* **optimize_mode** (*'maximize' or 'minimize'*) - If 'maximize', the tuner will target to maximize metrics. If 'minimize', the tuner will target to minimize metrics.
+* **trials_per_update** (*int, optional, default = 20*) - The number of trials to be used for one update. It is recommended that this number be larger than `trialConcurrency` and that `trialConcurrency` be a divisor of `trials_per_update`. Note that `trials_per_update` should be divisible by `minibatch_size`.
+* **epochs_per_update** (*int, optional, default = 4*) - The number of epochs for one update.
+* **minibatch_size** (*int, optional, default = 4*) - Mini-batch size (i.e., number of trials for a mini-batch) for the update. Note that `trials_per_update` should be divisible by `minibatch_size`.
+* **ent_coef** (*float, optional, default = 0.0*) - Policy entropy coefficient in the optimization objective.
+* **lr** (*float, optional, default = 3e-4*) - Learning rate of the model (LSTM network), constant.
+* **vf_coef** (*float, optional, default = 0.5*) - Value function loss coefficient in the optimization objective.
+* **max_grad_norm** (*float, optional, default = 0.5*) - Gradient norm clipping coefficient.
+* **gamma** (*float, optional, default = 0.99*) - Discounting factor.
+* **lam** (*float, optional, default = 0.95*) - Advantage estimation discounting factor (lambda in the paper).
+* **cliprange** (*float, optional, default = 0.2*) - Cliprange in the PPO algorithm, constant.
+
+**Usage example**
+
+```yaml
+# config.yml
+tuner:
+  builtinTunerName: PPOTuner
+  classArgs:
+    optimize_mode: maximize
+```
\ No newline at end of file
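The constraints on the PPOTuner `classArgs` listed above (`trials_per_update` divisible by `minibatch_size`, and the recommendation relating `trials_per_update` to `trialConcurrency`) can be checked before launching an experiment. The helper below is a minimal sketch; `check_ppo_class_args` is a hypothetical name and is not part of NNI. Treating the divisibility rule as an error and the concurrency advice as a warning is an assumption made for illustration.

```python
def check_ppo_class_args(trials_per_update=20, minibatch_size=4, trial_concurrency=1):
    """Return (errors, warnings) for a PPOTuner configuration.

    Hypothetical helper: the defaults mirror the documented defaults above.
    """
    errors, warnings = [], []

    # Documented constraint: trials_per_update should be divisible by minibatch_size.
    if trials_per_update % minibatch_size != 0:
        errors.append("trials_per_update is not divisible by minibatch_size")

    # Documented recommendation: trials_per_update larger than trialConcurrency,
    # with trialConcurrency a divisor of trials_per_update.
    if trials_per_update <= trial_concurrency:
        warnings.append("trials_per_update should be larger than trialConcurrency")
    elif trials_per_update % trial_concurrency != 0:
        warnings.append("trialConcurrency should be a divisor of trials_per_update")

    return errors, warnings


errors, warnings = check_ppo_class_args(20, 4, trial_concurrency=5)
```

With the default `trials_per_update` of 20 and `minibatch_size` of 4, a `trialConcurrency` of 5 satisfies both the constraint and the recommendation, so both lists come back empty.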
diff --git a/docs/en_US/Tuner/GPTuner.md b/docs/en_US/Tuner/GPTuner.md
index 9ef49db2bb..0159c3d5c9 100644
--- a/docs/en_US/Tuner/GPTuner.md
+++ b/docs/en_US/Tuner/GPTuner.md
@@ -7,4 +7,6 @@ Bayesian optimization works by constructing a posterior distribution of function
GP Tuner is designed to minimize/maximize the number of steps required to find a combination of parameters that are close to the optimal combination. To do so, this method uses a proxy optimization problem (finding the maximum of the acquisition function) that, albeit still a hard problem, is cheaper (in the computational sense) and common tools can be employed. Therefore Bayesian Optimization is most adequate for situations where sampling the function to be optimized is a very expensive endeavor.
+Note that the only acceptable types of search space are `randint`, `uniform`, `quniform`, `loguniform`, `qloguniform`, and numerical `choice`.
+
This optimization approach is described in Section 3 of [Algorithms for Hyper-Parameter Optimization](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf).
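The restriction to numerical search-space values follows from how a Gaussian Process compares points: its kernel is a function of the distance between them. The sketch below uses an RBF (squared-exponential) kernel purely as an illustration; whether GP Tuner uses this exact kernel is an assumption, and `rbf_kernel` is a hypothetical helper, not an NNI function.

```python
import math


def rbf_kernel(x, y, length_scale=1.0):
    """Similarity of two points, derived from their squared Euclidean distance."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * length_scale ** 2))


# Numerical values: nearby points are more similar than distant ones.
near = rbf_kernel([0.1], [0.2])
far = rbf_kernel([0.1], [0.9])

# String-valued `choice` entries have no meaningful difference to compute.
try:
    rbf_kernel(["relu"], ["tanh"])
    string_choice_ok = True
except TypeError:
    string_choice_ok = False
```

Subtracting two strings raises `TypeError`, which is the programmatic form of the limitation: without a numerical encoding there is no distance for the kernel to work with.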
diff --git a/docs/en_US/Tuner/MetisTuner.md b/docs/en_US/Tuner/MetisTuner.md
index 7c0c8e3e37..d653796e01 100644
--- a/docs/en_US/Tuner/MetisTuner.md
+++ b/docs/en_US/Tuner/MetisTuner.md
@@ -15,6 +15,6 @@ It finds the global optimal point in the Gaussian Process space. This point repr
It identifies the next hyper-parameter candidate. This is achieved by inferring the potential information gain of exploration, exploitation, and re-sampling.
-Note that the only acceptable types of search space are `choice`, `quniform`, `uniform` and `randint`.
+Note that the only acceptable types of search space are `quniform`, `uniform`, `randint`, and numerical `choice`.
More details can be found in our paper: https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/
\ No newline at end of file
diff --git a/docs/en_US/Tuner/PPOTuner.md b/docs/en_US/Tuner/PPOTuner.md
new file mode 100644
index 0000000000..6afdf89503
--- /dev/null
+++ b/docs/en_US/Tuner/PPOTuner.md
@@ -0,0 +1,20 @@
+PPO Tuner on NNI
+===
+
+## PPOTuner
+
+This tuner is designed for NNI's NAS interface and uses the [PPO algorithm](https://arxiv.org/abs/1707.06347). The implementation inherits the main logic of the implementation [here](https://github.com/openai/baselines/tree/master/baselines/ppo2) (i.e., ppo2 from OpenAI) and is adapted for the NAS scenario.
+
+It can successfully tune the [mnist-nas example](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas), with the following result:
+
+![](../../img/ppo_mnist.png)
+
+We also tune [the macro search space for image classification in the ENAS paper](https://github.com/microsoft/nni/tree/master/examples/trials/nas_cifar10) (with a limited number of epochs for each trial, i.e., 8 epochs), which is implemented using the NAS interface and tuned with PPOTuner. Figure 7 in the [ENAS paper](https://arxiv.org/pdf/1802.03268.pdf) shows what the search space looks like:
+
+![](../../img/enas_search_space.png)
+
+The figure above shows one chosen architecture; we use it to illustrate what the search space looks like. Each square is a layer whose operation can be chosen from 6 operations. Each dashed line is a skip connection, and each square layer can choose 0 or 1 skip connection that takes the output of a previous layer. __Note that__ in the original macro search space each square layer can choose any number of skip connections, while our implementation only allows 0 or 1.
+
+The result is shown in the figure below (with the experiment config [here](https://github.com/microsoft/nni/blob/master/examples/trials/nas_cifar10/config_ppo.yml)):
+
+![](../../img/ppo_cifar10.png)
diff --git a/docs/en_US/Tutorial/ExperimentConfig.md b/docs/en_US/Tutorial/ExperimentConfig.md
index 0f8970bd3d..d610d25302 100644
--- a/docs/en_US/Tutorial/ExperimentConfig.md
+++ b/docs/en_US/Tutorial/ExperimentConfig.md
@@ -1,8 +1,8 @@
# Experiment config reference
-A config file is needed when create an experiment, the path of the config file is provide to nnictl.
-The config file is written in YAML format, and need to be written correctly.
-This document describes the rule to write config file, and will provide some examples and templates.
+A config file is needed when creating an experiment. The path of the config file is provided to `nnictl`.
+The config file is in YAML format.
+This document describes the rules to write the config file, and provides some examples and templates.
- [Experiment config reference](#Experiment-config-reference)
- [Template](#Template)
@@ -35,7 +35,7 @@ tuner:
classArgs:
#choice: maximize, minimize
optimize_mode:
- gpuNum:
+ gpuIndices:
trial:
command:
codeDir:
@@ -71,14 +71,13 @@ tuner:
classArgs:
#choice: maximize, minimize
optimize_mode:
- gpuNum:
+ gpuIndices:
assessor:
#choice: Medianstop
builtinAssessorName:
classArgs:
#choice: maximize, minimize
optimize_mode:
- gpuNum:
trial:
command:
codeDir:
@@ -113,14 +112,13 @@ tuner:
classArgs:
#choice: maximize, minimize
optimize_mode:
- gpuNum:
+ gpuIndices:
assessor:
#choice: Medianstop
builtinAssessorName:
classArgs:
#choice: maximize, minimize
optimize_mode:
- gpuNum:
trial:
command:
codeDir:
@@ -245,11 +243,11 @@ machineList:
* __builtinTunerName__ and __classArgs__
* __builtinTunerName__
- __builtinTunerName__ specifies the name of system tuner, NNI sdk provides four kinds of tuner, including {__TPE__, __Random__, __Anneal__, __Evolution__, __BatchTuner__, __GridSearch__}
+ __builtinTunerName__ specifies the name of a built-in tuner. The NNI SDK provides the tuners introduced [here](../Tuner/BuiltinTuner.md).
* __classArgs__
- __classArgs__ specifies the arguments of tuner algorithm. If the __builtinTunerName__ is in {__TPE__, __Random__, __Anneal__, __Evolution__}, user should set __optimize_mode__.
+ __classArgs__ specifies the arguments of tuner algorithm. Please refer to [this file](../Tuner/BuiltinTuner.md) for the configurable arguments of each built-in tuner.
* __codeDir__, __classFileName__, __className__ and __classArgs__
* __codeDir__
@@ -264,16 +262,16 @@ machineList:
__classArgs__ specifies the arguments of tuner algorithm.
- * __gpuNum__
-
- __gpuNum__ specifies the gpu number to run the tuner process. The value of this field should be a positive number. If the field is not set, NNI will not set `CUDA_VISIBLE_DEVICES` in script (that is, will not control the visibility of GPUs on trial command through `CUDA_VISIBLE_DEVICES`), and will not manage gpu resource.
+ * __gpuIndices__
- Note: users could only specify one way to set tuner, for example, set {tunerName, optimizationMode} or {tunerCommand, tunerCwd}, and could not set them both.
+ __gpuIndices__ specifies the GPUs that can be used by the tuner process. A single GPU index or multiple GPU indices can be specified; multiple indices are separated by commas (`,`), such as `1` or `0,1,3`. If the field is not set, `CUDA_VISIBLE_DEVICES` will be '' in the script, that is, no GPU is visible to the tuner.
* __includeIntermediateResults__
If __includeIntermediateResults__ is true, the last intermediate result of the trial that is early stopped by assessor is sent to tuner as final result. The default value of __includeIntermediateResults__ is false.
+ Note: users can use only one way to specify the tuner, either `builtinTunerName` and `classArgs`, or `codeDir`, `classFileName`, `className` and `classArgs`.
+
* __assessor__
* Description
@@ -282,7 +280,7 @@ machineList:
* __builtinAssessorName__ and __classArgs__
* __builtinAssessorName__
- __builtinAssessorName__ specifies the name of system assessor, NNI sdk provides one kind of assessor {__Medianstop__}
+ __builtinAssessorName__ specifies the name of a built-in assessor. The NNI SDK provides the assessors introduced [here](../Assessor/BuiltinAssessor.md).
* __classArgs__
__classArgs__ specifies the arguments of assessor algorithm
@@ -305,11 +303,39 @@ machineList:
__classArgs__ specifies the arguments of assessor algorithm.
- * __gpuNum__
+ Note: users can use only one way to specify the assessor, either `builtinAssessorName` and `classArgs`, or `codeDir`, `classFileName`, `className` and `classArgs`. If users do not want to use an assessor, the assessor field should be left empty.
+
+* __advisor__
+ * Description
- __gpuNum__ specifies the gpu number to run the assessor process. The value of this field should be a positive number.
+ __advisor__ specifies the advisor algorithm in the experiment. There are two ways to specify an advisor. One way is to use an advisor provided by the NNI SDK, which requires setting __builtinAdvisorName__ and __classArgs__. The other way is to use your own advisor file, which requires setting __codeDir__, __classFileName__, __className__ and __classArgs__.
+ * __builtinAdvisorName__ and __classArgs__
+ * __builtinAdvisorName__
- Note: users' could only specify one way to set assessor, for example,set {assessorName, optimizationMode} or {assessorCommand, assessorCwd}, and users could not set them both.If users do not want to use assessor, assessor fileld should leave to empty.
+ __builtinAdvisorName__ specifies the name of a built-in advisor, NNI sdk provides [different advisors](../Tuner/BuiltinTuner.md).
+
+ * __classArgs__
+
+ __classArgs__ specifies the arguments of the advisor algorithm. Please refer to [this file](../Tuner/BuiltinTuner.md) for the configurable arguments of each built-in advisor.
+ * __codeDir__, __classFileName__, __className__ and __classArgs__
+ * __codeDir__
+
+ __codeDir__ specifies the directory of advisor code.
+ * __classFileName__
+
+ __classFileName__ specifies the name of advisor file.
+ * __className__
+
+ __className__ specifies the name of advisor class.
+ * __classArgs__
+
+ __classArgs__ specifies the arguments of advisor algorithm.
+
+ * __gpuIndices__
+
+ __gpuIndices__ specifies the GPUs that can be used by the advisor process. A single GPU index or multiple GPU indices can be specified; multiple indices are separated by commas (`,`), such as `1` or `0,1,3`. If the field is not set, `CUDA_VISIBLE_DEVICES` will be '' in the script, that is, no GPU is visible to the advisor.
+
+ Note: users can use only one way to specify the advisor, either `builtinAdvisorName` and `classArgs`, or `codeDir`, `classFileName`, `className` and `classArgs`.
* __trial(local, remote)__
@@ -560,7 +586,6 @@ machineList:
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
trial:
command: python3 mnist.py
codeDir: /nni/mnist
@@ -586,14 +611,12 @@ machineList:
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
assessor:
#choice: Medianstop
builtinAssessorName: Medianstop
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
trial:
command: python3 mnist.py
codeDir: /nni/mnist
@@ -620,7 +643,6 @@ machineList:
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
assessor:
codeDir: /nni/assessor
classFileName: myassessor.py
@@ -628,7 +650,6 @@ machineList:
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
trial:
command: python3 mnist.py
codeDir: /nni/mnist
@@ -656,7 +677,6 @@ machineList:
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
- gpuNum: 0
trial:
command: python3 mnist.py
codeDir: /nni/mnist
@@ -780,7 +800,6 @@ machineList:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
- gpuNum: 0
trial:
codeDir: .
worker:
diff --git a/docs/en_US/Tutorial/Nnictl.md b/docs/en_US/Tutorial/Nnictl.md
index 5b5899e42d..6dac0bfc7f 100644
--- a/docs/en_US/Tutorial/Nnictl.md
+++ b/docs/en_US/Tutorial/Nnictl.md
@@ -10,6 +10,7 @@ nnictl support commands:
* [nnictl create](#create)
* [nnictl resume](#resume)
+* [nnictl view](#view)
* [nnictl stop](#stop)
* [nnictl update](#update)
* [nnictl trial](#trial)
@@ -104,6 +105,35 @@ Debug mode will disable version check function in Trialkeeper.
nnictl resume [experiment_id] --port 8088
```
+
+
+![](https://placehold.it/15/1589F0/000000?text=+) `nnictl view`
+
+* Description
+
+ You can use this command to view a stopped experiment.
+
+* Usage
+
+ ```bash
+ nnictl view [OPTIONS]
+ ```
+
+* Options
+
+ |Name, shorthand|Required|Default|Description|
+ |------|------|------ |------|
+ |id| True| |The id of the experiment you want to view|
+ |--port, -p| False| |Rest port of the experiment you want to view|
+
+* Example
+
+ > view an experiment with specified port 8088
+
+ ```bash
+ nnictl view [experiment_id] --port 8088
+ ```
+
![](https://placehold.it/15/1589F0/000000?text=+) `nnictl stop`
diff --git a/docs/en_US/Tutorial/SearchSpaceSpec.md b/docs/en_US/Tutorial/SearchSpaceSpec.md
index ea0d6bdf9e..b892a5e1e5 100644
--- a/docs/en_US/Tutorial/SearchSpaceSpec.md
+++ b/docs/en_US/Tutorial/SearchSpaceSpec.md
@@ -94,7 +94,7 @@ All types of sampling strategies and their parameter are listed here:
Known Limitations:
-* Note that Metis Tuner only supports numerical `choice` now
+* GP Tuner and Metis Tuner support only **numerical values** in the search space (`choice` type values can be non-numerical with other tuners, e.g. string values). Both GP Tuner and Metis Tuner use a Gaussian Process Regressor (GPR). GPR makes predictions based on a kernel function and the 'distance' between different points, and it is hard to define a true distance between non-numerical values.
* Note that for nested search space:
diff --git a/docs/img/enas_search_space.png b/docs/img/enas_search_space.png
new file mode 100644
index 0000000000..9280cc37bb
Binary files /dev/null and b/docs/img/enas_search_space.png differ
diff --git a/docs/img/ppo_cifar10.png b/docs/img/ppo_cifar10.png
new file mode 100644
index 0000000000..b2061a07f6
Binary files /dev/null and b/docs/img/ppo_cifar10.png differ
diff --git a/docs/img/ppo_mnist.png b/docs/img/ppo_mnist.png
new file mode 100644
index 0000000000..3c5a00c176
Binary files /dev/null and b/docs/img/ppo_mnist.png differ
diff --git a/docs/zh_CN/AdvancedFeature/GeneralNasInterfaces.md b/docs/zh_CN/AdvancedFeature/GeneralNasInterfaces.md
index 71c05b2da9..c002614c27 100644
--- a/docs/zh_CN/AdvancedFeature/GeneralNasInterfaces.md
+++ b/docs/zh_CN/AdvancedFeature/GeneralNasInterfaces.md
@@ -1,6 +1,6 @@
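-# General Programming Interface for Neural Architecture Search (beta)
+# NNI Programming Interface for Neural Architecture Search (NAS)
-*This is a feature under testing; currently only the general NAS programming interface is implemented. Weight sharing will be supported in a later release.*
+*This is an **experimental feature**. Currently, only the general NAS programming interface is implemented. Weight sharing will be supported in a later release.*
Automatic Neural Architecture Search (NAS) plays an increasingly important role in finding better models. Recent research has demonstrated the feasibility of automatic NAS and has found models that outperform manually designed and tuned ones. Representative algorithms include [NASNet](https://arxiv.org/abs/1707.07012), [ENAS](https://arxiv.org/abs/1802.03268), [DARTS](https://arxiv.org/abs/1806.09055), [Network Morphism](https://arxiv.org/abs/1806.10282), and [Evolution](https://arxiv.org/abs/1703.01041). New algorithms keep emerging. However, implementing these algorithms takes considerable effort, and it is hard to reuse the code base of one algorithm to implement another.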
diff --git a/docs/zh_CN/CommunitySharings/SptagAutoTune.md b/docs/zh_CN/CommunitySharings/SptagAutoTune.md
new file mode 100644
index 0000000000..456d47b621
--- /dev/null
+++ b/docs/zh_CN/CommunitySharings/SptagAutoTune.md
@@ -0,0 +1,7 @@
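+# Automatically Tuning SPTAG with NNI
+
+[SPTAG](https://github.com/microsoft/SPTAG) (Space Partition Tree And Graph) is a tool for nearest neighbor search over large-scale vectors, released jointly by [Microsoft Research (MSR)](https://www.msra.cn/) and the [Microsoft Bing team](https://www.bing.com/).
+
+This tool assumes that samples are represented as vectors and that the vectors can be compared by L2 or cosine distance. Given a query vector, it returns the set of vectors with the smallest L2 or cosine distance to it. SPTAG provides two methods: kd-tree with its relative neighborhood graph (SPTAG-KDT), and balanced k-means tree with its relative neighborhood graph (SPTAG-BKT). SPTAG-KDT is better in index-building efficiency, while SPTAG-BKT is better in search accuracy on high-dimensional data.
+
+SPTAG has dozens of parameters that can be tuned for a specific scenario or dataset. NNI is a great tool for tuning these parameters automatically. The authors of SPTAG tried NNI for auto-tuning, easily found a well-performing parameter combination, and shared it in the SPTAG [documentation](https://github.com/microsoft/SPTAG/blob/master/docs/Parameters.md). Refer to that document for a detailed tutorial.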
\ No newline at end of file
diff --git a/docs/zh_CN/CommunitySharings/community_sharings.rst b/docs/zh_CN/CommunitySharings/community_sharings.rst
index 98e6633473..509b68aa29 100644
--- a/docs/zh_CN/CommunitySharings/community_sharings.rst
+++ b/docs/zh_CN/CommunitySharings/community_sharings.rst
@@ -8,6 +8,7 @@
:maxdepth: 2
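Using NNI in Recommenders
+ Automatically Tuning SPTAG with NNI
Comparison of Neural Architecture Search (NAS) Algorithms
Comparison of Hyperparameter Tuning Algorithms
Parallelizing TPE Search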
diff --git a/docs/zh_CN/Release_v1.0.md b/docs/zh_CN/Release_v1.0.md
new file mode 100644
index 0000000000..6bec474194
--- /dev/null
+++ b/docs/zh_CN/Release_v1.0.md
@@ -0,0 +1,29 @@
+
+
+
+
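+From September 2018 to September 2019, we have kept moving forward ...
+
+**Great news!** NNI v1.0 brings greater **scalability** and **ease of use**. Based on various [tuning algorithms](./Tuner/BuiltinTuner.md), NNI already supports hyperparameter tuning, neural architecture search, automated feature engineering, and other features that are very useful to algorithm engineers. In addition, NNI v1.0 brings a large number of improvements, including optimized tuning algorithms, a [simpler and more intuitive Web UI](./Tutorial/WebUI.md), and [more diverse platforms](./TrainingService/SupportTrainingService.md). NNI has become a more intelligent automated machine learning (AutoML) toolkit.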
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
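+ **Step 1**: Install NNI v1.0 by following the [tutorial](./Tutorial/Installation.md).
+ **Step 2**: Find the "Hello world" example and get started by following the [tutorial](./Tutorial/QuickStart.md).
+ **Step 3**: Get familiar with the [Web UI](./Tutorial/WebUI.md) and use NNI for automated machine learning!
+
+
+Fully automated tools greatly improve the efficiency of the tuning process. For more details about v1.0, see [Release 1.0](https://github.com/microsoft/nni/releases). For future plans, see the [roadmap](https://github.com/microsoft/nni/wiki/Roadmap). More contributors are also welcome; see [how to contribute](./Tutorial/Contributing.md) for the various ways to participate.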
\ No newline at end of file
diff --git a/docs/zh_CN/SupportedFramework_Library.md b/docs/zh_CN/SupportedFramework_Library.md
index 8e54791a4d..481a921a4b 100644
--- a/docs/zh_CN/SupportedFramework_Library.md
+++ b/docs/zh_CN/SupportedFramework_Library.md
@@ -10,7 +10,7 @@
* [CIFAR-10](TrialExample/Cifar10Examples.md)
-* [TGS salt identification chanllenge](../../examples/trials/kaggle-tgs-salt/README.md)
+* [TGS salt identification challenge](../../examples/trials/kaggle-tgs-salt/README_zh_CN.md)
* [Network morphism](../../examples/trials/network_morphism/README_zh_CN.md)
diff --git a/docs/zh_CN/TrainingService/PaiMode.md b/docs/zh_CN/TrainingService/PaiMode.md
index 5ac564f15a..ee29276fd2 100644
--- a/docs/zh_CN/TrainingService/PaiMode.md
+++ b/docs/zh_CN/TrainingService/PaiMode.md
@@ -59,6 +59,35 @@ paiConfig:
* authFile
* Optional. In pai mode, sets the authentication file for a private Docker registry; see the [reference document](https://github.com/microsoft/pai/blob/2ea69b45faa018662bc164ed7733f6fdbb4c42b3/docs/faq.md#q-how-to-use-private-docker-registry-job-image-when-submitting-an-openpai-job). Provide the local path of the authFile, and NNI will upload it.
+* portList
+
+ * Optional. Sets the portList of OpenPAI, which specifies the list of ports used in the container; see the [reference document](https://github.com/microsoft/pai/blob/b2324866d0280a2d22958717ea6025740f71b9f0/docs/job_tutorial.md#specification).
+ An example:
+ portList:
+ - label: test
+ beginAt: 8080
+ portNumber: 2
+
+
+ Suppose a port is needed to run TensorBoard in the MNIST example. The first step is to write a wrapper script `launch_pai.sh` for `mnist.py`.
+
+ ```bash
+ export TENSORBOARD_PORT=PAI_PORT_LIST_${PAI_CURRENT_TASK_ROLE_NAME}_0_tensorboard
+ tensorboard --logdir . --port ${!TENSORBOARD_PORT} &
+ python3 mnist.py
+ ```
+
+ The portList part of the configuration is as follows:
+
+ ```yaml
+ trial:
+ command: bash launch_pai.sh
+ portList:
+ - label: tensorboard
+ beginAt: 0
+ portNumber: 1
+ ```
+
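The wrapper script above relies on indirect variable expansion: it first builds the *name* of an environment variable and then reads that variable's value. A minimal Python sketch of the same lookup follows. The helper name `pai_port`, the role name `worker`, and the port value are hypothetical and only serve the illustration; the variable naming pattern is taken from the script above.

```python
import os


def pai_port(label, index=0, environ=os.environ):
    """Return the port assigned to `label`, following the naming pattern
    PAI_PORT_LIST_<task role>_<index>_<label> used in launch_pai.sh."""
    role = environ["PAI_CURRENT_TASK_ROLE_NAME"]
    var_name = "PAI_PORT_LIST_{}_{}_{}".format(role, index, label)
    return int(environ[var_name])


# Hypothetical environment, standing in for what the container would provide.
fake_env = {
    "PAI_CURRENT_TASK_ROLE_NAME": "worker",
    "PAI_PORT_LIST_worker_0_tensorboard": "35001",
}
port = pai_port("tensorboard", environ=fake_env)
print(port)
```

Passing the environment as a parameter keeps the lookup testable outside of a real OpenPAI container.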
After completing and saving the NNI Experiment configuration file (for example, as exp_pai.yml), run the following command:
nnictl create --config exp_pai.yml
diff --git a/docs/zh_CN/TrainingService/SupportTrainingService.md b/docs/zh_CN/TrainingService/SupportTrainingService.md
index 0e85797863..fbf6f6a2cd 100644
--- a/docs/zh_CN/TrainingService/SupportTrainingService.md
+++ b/docs/zh_CN/TrainingService/SupportTrainingService.md
@@ -36,4 +36,4 @@ TrainingService 的声明如下:
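The TrainingService parent class has several abstract functions; users need to inherit from the parent class and implement all of them.
-For more information on how to implement a TrainingService, [refer to this document](HowToImplementTrainingService.md).
\ No newline at end of file
+For more information on how to implement a TrainingService, [refer to this document](https://github.com/microsoft/nni/blob/master/docs/zh_CN/TrainingService/HowToImplementTrainingService.md).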
\ No newline at end of file
diff --git a/docs/zh_CN/Tuner/BuiltinTuner.md b/docs/zh_CN/Tuner/BuiltinTuner.md
index 6fe36f2dd0..132e5af66c 100644
--- a/docs/zh_CN/Tuner/BuiltinTuner.md
+++ b/docs/zh_CN/Tuner/BuiltinTuner.md
@@ -115,6 +115,12 @@ tuner:
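This algorithm has a relatively high demand for computing resources. It needs a very large initial population to avoid falling into a local optimum. It is a good fit if trials are short or an Assessor is used. It is also recommended when the trial code supports weight transfer, that is, when each trial inherits the converged weights from the previous round, which greatly speeds up training. [Detailed description](./EvolutionTuner.md)
+**Parameters**
+
+* **optimize_mode** (*maximize or minimize, optional, default = maximize*) - If 'maximize', the Tuner aims to maximize the metric. If 'minimize', the Tuner aims to minimize the metric.
+
+* **population_size** (*int (greater than 0), optional, default = 20*) - The size of the population (number of trials) in the evolution Tuner.
+
**Example**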
```yaml
@@ -123,6 +129,7 @@ tuner:
builtinTunerName: Evolution
classArgs:
optimize_mode: maximize
+ population_size: 100
```
diff --git a/docs/zh_CN/Tuner/GridsearchTuner.md b/docs/zh_CN/Tuner/GridsearchTuner.md
index a1003d19a5..90b4279c39 100644
--- a/docs/zh_CN/Tuner/GridsearchTuner.md
+++ b/docs/zh_CN/Tuner/GridsearchTuner.md
@@ -2,4 +2,6 @@
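## Grid Search
-Grid Search performs an exhaustive search over all hyperparameter combinations defined in the search space file. Note that the search space supports only `choice`, `quniform`, and `qloguniform`. The number `q` in `quniform` and `qloguniform` has a different meaning here (different from the [search space](../Tutorial/SearchSpaceSpec.md) specification): it is the number of values sampled evenly between `low` and `high`.
\ No newline at end of file
+Grid Search performs an exhaustive search over all hyperparameter combinations defined in the search space file.
+
+Note that the search space supports only `choice`, `quniform`, and `randint`.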
\ No newline at end of file
diff --git a/docs/zh_CN/Tutorial/ExperimentConfig.md b/docs/zh_CN/Tutorial/ExperimentConfig.md
index 700f8c00d2..21adc03b76 100644
--- a/docs/zh_CN/Tutorial/ExperimentConfig.md
+++ b/docs/zh_CN/Tutorial/ExperimentConfig.md
@@ -1,6 +1,6 @@
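# Experiment Config Reference
-When creating an Experiment, the path of the configuration file must be provided to the nnictl command. The configuration file is in YAML format, and its format must be correct. This document describes the content of the configuration file and provides some examples and templates.
+A configuration file is needed to create an Experiment. The path of the configuration file is passed to the `nnictl` command. The configuration file is in YAML format. This document describes the content of the configuration file and provides some examples and templates.
- [Experiment Config Reference](#Experiment-config-reference)
- [Template](#Template)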
@@ -549,6 +549,10 @@ machineList:
- **azureShare**
**azureShare** is the share parameter of Azure file storage.
+
+ - **uploadRetryCount**
+
+ If uploading files to Azure Storage fails, NNI retries. This field specifies the number of retries.
- **paiConfig**
diff --git a/examples/trials/efficientnet/.gitignore b/examples/trials/efficientnet/.gitignore
new file mode 100644
index 0000000000..d94725b323
--- /dev/null
+++ b/examples/trials/efficientnet/.gitignore
@@ -0,0 +1 @@
+EfficientNet-PyTorch
\ No newline at end of file
diff --git a/examples/trials/efficientnet/README.md b/examples/trials/efficientnet/README.md
new file mode 100644
index 0000000000..2d8f436594
--- /dev/null
+++ b/examples/trials/efficientnet/README.md
@@ -0,0 +1,19 @@
+# EfficientNet
+
+[EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946)
+
+Provided here are: Search space and tuners for finding the best tuple (alpha, beta, gamma) for EfficientNet-B1 with grid search, as discussed in Section 3.3 in [paper](https://arxiv.org/abs/1905.11946).
+
+## Instructions
+
+1. Set your working directory here in this directory.
+2. Run `git clone https://github.com/ultmaster/EfficientNet-PyTorch` to clone this modified version of [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch). The modifications were done to adhere to the original [Tensorflow version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) as closely as possible (including EMA, label smoothing, etc.); also added is the part that gets parameters from the tuner and reports intermediate/final results. Clone it into `EfficientNet-PyTorch`; files like `main.py` and `train_imagenet.sh` will appear inside, as specified in the configuration files.
+3. Run `nnictl create --config config_net.yml` to find the best EfficientNet-B1. Adjust the training service (PAI/local/remote) and batch size in the config files according to the environment.
+
+For training on ImageNet, read `EfficientNet-PyTorch/train_imagenet.sh`. Download ImageNet beforehand, extract it following the [PyTorch format](https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet), and then replace `/mnt/data/imagenet` in that script with the location of the ImageNet storage. This file should also be a good example to follow for mounting ImageNet into the container on OpenPAI.
+
+## Results
+
+The following image is a screenshot demonstrating the relationship between acc@1 and alpha, beta, gamma.
+
+![](assets/search_result.png)
\ No newline at end of file
diff --git a/examples/trials/efficientnet/README_zh_CN.md b/examples/trials/efficientnet/README_zh_CN.md
new file mode 100644
index 0000000000..2f4ac5e65f
--- /dev/null
+++ b/examples/trials/efficientnet/README_zh_CN.md
@@ -0,0 +1,19 @@
+# EfficientNet
+
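+[EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946)
+
+Provided here are the search space and Tuner for finding the best tuple (alpha, beta, gamma) for EfficientNet-B1 with grid search, as discussed in Section 3.3 of the [paper](https://arxiv.org/abs/1905.11946).
+
+## Instructions
+
+1. Set this directory as the working directory.
+2. Run `git clone https://github.com/ultmaster/EfficientNet-PyTorch` to clone the modified [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch). The modifications follow the original [TensorFlow version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) as closely as possible (including EMA, label smoothing, etc.); code was also added to get parameters from the Tuner and report intermediate and final results. Clone it into `EfficientNet-PyTorch`; files such as `main.py` and `train_imagenet.sh` will then be at the paths specified in the configuration files.
+3. Run `nnictl create --config config_net.yml` to find the best EfficientNet-B1. Adjust the training service (OpenPAI, local, remote) and batch size according to the environment.
+
+For training on ImageNet, read `EfficientNet-PyTorch/train_imagenet.sh`. Download ImageNet, extract it following the [PyTorch format](https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet), and then replace `/mnt/data/imagenet` with the path of ImageNet. This file is also an example of how to mount ImageNet into an OpenPAI container.
+
+## Results
+
+The following image shows the relationship between acc@1 and alpha, beta, gamma.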
+![](assets/search_result.png)
\ No newline at end of file
diff --git a/examples/trials/efficientnet/assets/search_result.png b/examples/trials/efficientnet/assets/search_result.png
new file mode 100644
index 0000000000..91c4f6dac7
Binary files /dev/null and b/examples/trials/efficientnet/assets/search_result.png differ
diff --git a/examples/trials/efficientnet/config_net.yml b/examples/trials/efficientnet/config_net.yml
new file mode 100644
index 0000000000..3ae75ef46c
--- /dev/null
+++ b/examples/trials/efficientnet/config_net.yml
@@ -0,0 +1,28 @@
+authorName: unknown
+experimentName: example_efficient_net
+trialConcurrency: 8
+maxExecDuration: 48h
+maxTrialNum: 100
+trainingServicePlatform: pai
+searchSpacePath: search_net.json
+useAnnotation: false
+tuner:
+ codeDir: .
+ classFileName: tuner.py
+ className: FixedProductTuner
+ classArgs:
+ product: 2
+trial:
+ codeDir: EfficientNet-PyTorch
+ command: sh train_imagenet.sh
+ cpuNum: 4
+ memoryMB: 25000
+ shmMB: 25000
+ gpuNum: 1
+ virtualCluster: nni
+ image: msranni/nni:latest
+nniManagerIp:
+paiConfig:
+ userName:
+ passWord:
+ host:
diff --git a/examples/trials/efficientnet/search_net.json b/examples/trials/efficientnet/search_net.json
new file mode 100644
index 0000000000..bf45ba918d
--- /dev/null
+++ b/examples/trials/efficientnet/search_net.json
@@ -0,0 +1,14 @@
+{
+ "alpha": {
+ "_type": "quniform",
+ "_value": [1.0, 2.0, 0.1]
+ },
+ "beta": {
+ "_type": "quniform",
+ "_value": [1.0, 1.5, 0.1]
+ },
+ "gamma": {
+ "_type": "quniform",
+ "_value": [1.0, 1.5, 0.1]
+ }
+}
diff --git a/examples/trials/efficientnet/tuner.py b/examples/trials/efficientnet/tuner.py
new file mode 100644
index 0000000000..d091d40ac0
--- /dev/null
+++ b/examples/trials/efficientnet/tuner.py
@@ -0,0 +1,29 @@
+from nni.gridsearch_tuner.gridsearch_tuner import GridSearchTuner
+
+
+class FixedProductTuner(GridSearchTuner):
+ """
+ This tuner is essentially grid search, but it only keeps the parameter sets for which
+ alpha * beta^2 * gamma^2 is approximately `product`.
+ """
+
+ def __init__(self, product):
+ """
+ :param product: the constant provided, should be 2 in EfficientNet-B1
+ """
+ super().__init__()
+ self.product = product
+
+ def expand_parameters(self, para):
+ """
+ Filter out all qualified parameters
+ """
+ para = super().expand_parameters(para)
+ if para and all(key in para[0] for key in ["alpha", "beta", "gamma"]): # only filter sets that contain all three scaling factors
+ ret_para = []
+ for p in para:
+ prod = p["alpha"] * (p["beta"] ** 2) * (p["gamma"] ** 2)
+ if abs(prod - self.product) < 0.1:
+ ret_para.append(p)
+ return ret_para
+ return para
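The filtering idea in `FixedProductTuner` can be tried without NNI installed. The sketch below is a standalone illustration, not the tuner's actual code path: `quniform_grid` and `fixed_product_candidates` are hypothetical helpers, and the assumption that a `quniform` range `[low, high, q]` expands to `low, low + q, ..., high` is made here for illustration only. The ranges and the 0.1 tolerance mirror `search_net.json` and `tuner.py`.

```python
import itertools


def quniform_grid(low, high, q):
    """Enumerate low, low + q, ..., high; rounding avoids float drift."""
    steps = int(round((high - low) / q))
    return [round(low + i * q, 10) for i in range(steps + 1)]


def fixed_product_candidates(space, product=2.0, tol=0.1):
    """Keep grid points where alpha * beta^2 * gamma^2 is within tol of product."""
    names = ["alpha", "beta", "gamma"]
    grids = [quniform_grid(*space[name]) for name in names]
    kept = []
    for alpha, beta, gamma in itertools.product(*grids):
        if abs(alpha * beta ** 2 * gamma ** 2 - product) < tol:
            kept.append({"alpha": alpha, "beta": beta, "gamma": gamma})
    return kept


# Same ranges as search_net.json: (low, high, q)
space = {
    "alpha": (1.0, 2.0, 0.1),
    "beta": (1.0, 1.5, 0.1),
    "gamma": (1.0, 1.5, 0.1),
}
candidates = fixed_product_candidates(space)
print(len(candidates), "of", 11 * 6 * 6, "grid points kept")
```

Filtering after expansion keeps the grid definition declarative in the JSON file while the constraint lives in one place in code.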
diff --git a/examples/trials/mnist-nas/classic_mode/config_hpo.yml b/examples/trials/mnist-nas/classic_mode/config_hpo.yml
new file mode 100644
index 0000000000..3c04a62f9f
--- /dev/null
+++ b/examples/trials/mnist-nas/classic_mode/config_hpo.yml
@@ -0,0 +1,16 @@
+authorName: default
+experimentName: example_mnist
+trialConcurrency: 1
+maxExecDuration: 1h
+maxTrialNum: 10
+#choice: local, remote, pai
+trainingServicePlatform: local
+#choice: true, false
+useAnnotation: true
+tuner:
+ builtinTunerName: TPE
+trial:
+ command: python3 mnist.py --batch_num 200
+ codeDir: .
+ gpuNum: 0
+ nasMode: classic_mode
diff --git a/examples/trials/mnist-nas/config_ppo.yml b/examples/trials/mnist-nas/config_ppo.yml
new file mode 100644
index 0000000000..9be8e78570
--- /dev/null
+++ b/examples/trials/mnist-nas/config_ppo.yml
@@ -0,0 +1,19 @@
+authorName: NNI-example
+experimentName: example_mnist
+trialConcurrency: 1
+maxExecDuration: 100h
+maxTrialNum: 10000
+#choice: local, remote, pai
+trainingServicePlatform: local
+#choice: true, false
+useAnnotation: true
+tuner:
+ #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
+ #SMAC, PPO (SMAC and PPO should be installed through nnictl)
+ builtinTunerName: PPOTuner
+ classArgs:
+ optimize_mode: maximize
+trial:
+ command: python3 mnist.py
+ codeDir: .
+ gpuNum: 0
diff --git a/examples/trials/mnist-pytorch/mnist.py b/examples/trials/mnist-pytorch/mnist.py
index 5bc0becda9..91c68c12a9 100644
--- a/examples/trials/mnist-pytorch/mnist.py
+++ b/examples/trials/mnist-pytorch/mnist.py
@@ -108,17 +108,16 @@ def main(args):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
- # report intermediate result
- nni.report_intermediate_result(test_acc)
- logger.debug('test accuracy %g', test_acc)
- logger.debug('Pipe send intermediate result done.')
-
- test_acc = test(args, model, device, test_loader)
- # report final result
- nni.report_final_result(test_acc)
- logger.debug('Final result is %g', test_acc)
- logger.debug('Send final result done.')
-
+ if epoch < args['epochs']:
+ # report intermediate result
+ nni.report_intermediate_result(test_acc)
+ logger.debug('test accuracy %g', test_acc)
+ logger.debug('Pipe send intermediate result done.')
+ else:
+ # report final result
+ nni.report_final_result(test_acc)
+ logger.debug('Final result is %g', test_acc)
+ logger.debug('Send final result done.')
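The effect of the new branching can be seen in isolation with a small sketch. It assumes the surrounding loop runs `epoch` from 1 to `epochs` inclusive (the loop header is not visible in this hunk), and it replaces the `nni.report_*` calls with plain callbacks; `run_trial` is a hypothetical stand-in for `main`.

```python
def run_trial(epochs, evaluate, report_intermediate, report_final):
    """Report every epoch's accuracy once: intermediate for all but the
    last epoch, final for the last one."""
    for epoch in range(1, epochs + 1):
        acc = evaluate(epoch)
        if epoch < epochs:
            report_intermediate(acc)
        else:
            report_final(acc)


# Stub callbacks that collect what would be sent to NNI.
intermediate, final = [], []
run_trial(3, lambda epoch: epoch / 10, intermediate.append, final.append)
print(intermediate, final)
```

Compared with the removed code, the last epoch is evaluated only once and its accuracy is reported as the final result rather than as both an intermediate and a final result.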
def get_params():
diff --git a/examples/trials/nas_cifar10/README.md b/examples/trials/nas_cifar10/README.md
index 2f3b52a869..e6f03e0b58 100644
--- a/examples/trials/nas_cifar10/README.md
+++ b/examples/trials/nas_cifar10/README.md
@@ -2,7 +2,14 @@
===
Now we have an NAS example [NNI-NAS-Example](https://github.com/Crysple/NNI-NAS-Example) run in NNI using NAS interface from our contributors.
+
+We have included its trial code in this folder and provided example config files that show how to use the PPO tuner to tune the trial code.
+
+> Download data
+
+- `cd data && . download.sh`
+- `tar xzf cifar-10-python.tar.gz && mv cifar-10-batches-py cifar10`
Thanks our lovely contributors.
-And welcome more and more people to join us!
\ No newline at end of file
+And welcome more and more people to join us!
diff --git a/examples/trials/nas_cifar10/config_pai_ppo.yml b/examples/trials/nas_cifar10/config_pai_ppo.yml
new file mode 100644
index 0000000000..38156376bd
--- /dev/null
+++ b/examples/trials/nas_cifar10/config_pai_ppo.yml
@@ -0,0 +1,31 @@
+authorName: Unknown
+experimentName: enas_macro
+trialConcurrency: 20
+maxExecDuration: 2400h
+maxTrialNum: 20000
+#choice: local, remote
+trainingServicePlatform: pai
+#choice: true, false
+useAnnotation: true
+multiPhase: false
+versionCheck: false
+nniManagerIp: 0.0.0.0
+tuner:
+ builtinTunerName: PPOTuner
+ classArgs:
+ optimize_mode: maximize
+ trials_per_update: 60
+ epochs_per_update: 20
+ minibatch_size: 6
+trial:
+ command: sh ./macro_cifar10_pai.sh
+ codeDir: ./
+ gpuNum: 1
+ cpuNum: 1
+ memoryMB: 8196
+ image: msranni/nni:latest
+ virtualCluster: nni
+paiConfig:
+ userName: your_account
+ passWord: your_pwd
+ host: 0.0.0.0
diff --git a/examples/trials/nas_cifar10/config_ppo.yml b/examples/trials/nas_cifar10/config_ppo.yml
new file mode 100644
index 0000000000..8de1c5123f
--- /dev/null
+++ b/examples/trials/nas_cifar10/config_ppo.yml
@@ -0,0 +1,24 @@
+authorName: Unknown
+experimentName: enas_macro
+trialConcurrency: 4
+maxExecDuration: 2400h
+maxTrialNum: 20000
+#choice: local, remote
+trainingServicePlatform: local
+#choice: true, false
+useAnnotation: true
+multiPhase: false
+tuner:
+ builtinTunerName: PPOTuner
+ classArgs:
+ optimize_mode: maximize
+ trials_per_update: 60
+ epochs_per_update: 12
+ minibatch_size: 10
+ # use GPU 0 for this tuner
+ # to specify multiple GPUs, list their indices separated by commas, e.g. three GPUs: 0,1,2
+ gpuIndices: 0
+trial:
+ command: sh ./macro_cifar10.sh
+ codeDir: ./
+ gpuNum: 1
diff --git a/examples/trials/nas_cifar10/data/download.sh b/examples/trials/nas_cifar10/data/download.sh
new file mode 100755
index 0000000000..f00ac25724
--- /dev/null
+++ b/examples/trials/nas_cifar10/data/download.sh
@@ -0,0 +1 @@
+wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
diff --git a/examples/trials/nas_cifar10/macro_cifar10.sh b/examples/trials/nas_cifar10/macro_cifar10.sh
new file mode 100644
index 0000000000..863256d802
--- /dev/null
+++ b/examples/trials/nas_cifar10/macro_cifar10.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set -e
+export PYTHONPATH="$(pwd)"
+
+python3 src/cifar10/nni_child_cifar10.py \
+ --data_format="NCHW" \
+ --search_for="macro" \
+ --reset_output_dir \
+ --data_path="data/cifar10" \
+ --output_dir="outputs" \
+ --train_data_size=45000 \
+ --batch_size=100 \
+ --num_epochs=8 \
+ --log_every=50 \
+ --eval_every_epochs=1 \
+ --child_use_aux_heads \
+ --child_num_layers=12 \
+ --child_out_filters=36 \
+ --child_l2_reg=0.0002 \
+ --child_num_branches=6 \
+ --child_num_cell_layers=5 \
+ --child_keep_prob=0.50 \
+ --child_drop_path_keep_prob=0.60 \
+ --child_lr_cosine \
+ --child_lr_max=0.05 \
+ --child_lr_min=0.001 \
+ --child_lr_T_0=10 \
+ --child_lr_T_mul=2 \
+ --child_mode="subgraph" \
+ "$@"
+
diff --git a/examples/trials/nas_cifar10/macro_cifar10_pai.sh b/examples/trials/nas_cifar10/macro_cifar10_pai.sh
new file mode 100644
index 0000000000..226955edc7
--- /dev/null
+++ b/examples/trials/nas_cifar10/macro_cifar10_pai.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set -e
+export PYTHONPATH="$(pwd)"
+
+python3 src/cifar10/nni_child_cifar10.py \
+ --data_format="NCHW" \
+ --search_for="macro" \
+ --reset_output_dir \
+ --data_path="data/cifar10" \
+ --output_dir="outputs" \
+ --train_data_size=45000 \
+ --batch_size=100 \
+ --num_epochs=30 \
+ --log_every=50 \
+ --eval_every_epochs=1 \
+ --child_use_aux_heads \
+ --child_num_layers=12 \
+ --child_out_filters=36 \
+ --child_l2_reg=0.0002 \
+ --child_num_branches=6 \
+ --child_num_cell_layers=5 \
+ --child_keep_prob=0.50 \
+ --child_drop_path_keep_prob=0.60 \
+ --child_lr_cosine \
+ --child_lr_max=0.05 \
+ --child_lr_min=0.001 \
+ --child_lr_T_0=10 \
+ --child_lr_T_mul=2 \
+ --child_mode="subgraph" \
+ "$@"
+
diff --git a/examples/trials/nas_cifar10/src/__init__.py b/examples/trials/nas_cifar10/src/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/trials/nas_cifar10/src/cifar10/__init__.py b/examples/trials/nas_cifar10/src/cifar10/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/trials/nas_cifar10/src/cifar10/data_utils.py b/examples/trials/nas_cifar10/src/cifar10/data_utils.py
new file mode 100644
index 0000000000..b8a8c36339
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/cifar10/data_utils.py
@@ -0,0 +1,74 @@
+import os
+import sys
+import pickle
+import numpy as np
+import tensorflow as tf
+
+
+def _read_data(data_path, train_files):
+ """Reads CIFAR-10 format data. Always returns NHWC format.
+
+ Returns:
+ images: np tensor of size [N, H, W, C]
+ labels: np tensor of size [N]
+ """
+ images, labels = [], []
+ for file_name in train_files:
+ print(file_name)
+ full_name = os.path.join(data_path, file_name)
+ with open(full_name, "rb") as finp:
+ data = pickle.load(finp, encoding='latin1')
+ batch_images = data["data"].astype(np.float32) / 255.0
+ batch_labels = np.array(data["labels"], dtype=np.int32)
+ images.append(batch_images)
+ labels.append(batch_labels)
+ images = np.concatenate(images, axis=0)
+ labels = np.concatenate(labels, axis=0)
+ images = np.reshape(images, [-1, 3, 32, 32])
+ images = np.transpose(images, [0, 2, 3, 1])
+
+ return images, labels
+
+
+def read_data(data_path, num_valids=5000):
+ print("-" * 80)
+ print("Reading data")
+
+ images, labels = {}, {}
+
+ train_files = [
+ "data_batch_1",
+ "data_batch_2",
+ "data_batch_3",
+ "data_batch_4",
+ "data_batch_5",
+ ]
+ test_file = [
+ "test_batch",
+ ]
+ images["train"], labels["train"] = _read_data(data_path, train_files)
+
+ if num_valids:
+ images["valid"] = images["train"][-num_valids:]
+ labels["valid"] = labels["train"][-num_valids:]
+
+ images["train"] = images["train"][:-num_valids]
+ labels["train"] = labels["train"][:-num_valids]
+ else:
+ images["valid"], labels["valid"] = None, None
+
+ images["test"], labels["test"] = _read_data(data_path, test_file)
+
+ print("Preprocess: [subtract mean], [divide std]")
+ mean = np.mean(images["train"], axis=(0, 1, 2), keepdims=True)
+ std = np.std(images["train"], axis=(0, 1, 2), keepdims=True)
+
+ print("mean: {}".format(np.reshape(mean * 255.0, [-1])))
+ print("std: {}".format(np.reshape(std * 255.0, [-1])))
+
+ images["train"] = (images["train"] - mean) / std
+ if num_valids:
+ images["valid"] = (images["valid"] - mean) / std
+ images["test"] = (images["test"] - mean) / std
+
+ return images, labels
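The two array manipulations in `data_utils.py` (flat rows to NHWC, then per-channel normalization) can be checked on synthetic data. This is a minimal sketch using random values in place of the real CIFAR-10 batches; the batch size of 8 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
# Fake rows in the CIFAR-10 layout: 3 channel planes of 32x32, flattened.
flat = rng.random((8, 3 * 32 * 32), dtype=np.float32)

# Rows are channel-major, so reshape to NCHW first, then move channels last.
images = flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Reducing over batch, height and width leaves one statistic per channel.
mean = images.mean(axis=(0, 1, 2), keepdims=True)
std = images.std(axis=(0, 1, 2), keepdims=True)
normalized = (images - mean) / std

print(images.shape, mean.shape)
```

Because `keepdims=True` keeps the statistics at shape `(1, 1, 1, 3)`, they broadcast against the validation and test arrays as well, which is why `read_data` computes them on the training split only and reuses them.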
diff --git a/examples/trials/nas_cifar10/src/cifar10/general_child.py b/examples/trials/nas_cifar10/src/cifar10/general_child.py
new file mode 100644
index 0000000000..4e80dc340e
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/cifar10/general_child.py
@@ -0,0 +1,423 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+from src.common_ops import create_weight, batch_norm, batch_norm_with_mask, global_avg_pool, conv_op, pool_op
+from src.utils import count_model_params, get_train_ops, get_C, get_strides
+from src.cifar10.models import Model
+
+
+class GeneralChild(Model):
+ def __init__(self,
+ images,
+ labels,
+ cutout_size=None,
+ fixed_arc=None,
+ out_filters_scale=1,
+ num_layers=2,
+ num_branches=6,
+ out_filters=24,
+ keep_prob=1.0,
+ batch_size=32,
+ clip_mode=None,
+ grad_bound=None,
+ l2_reg=1e-4,
+ lr_init=0.1,
+ lr_dec_start=0,
+ lr_dec_every=10000,
+ lr_dec_rate=0.1,
+ lr_cosine=False,
+ lr_max=None,
+ lr_min=None,
+ lr_T_0=None,
+ lr_T_mul=None,
+ optim_algo=None,
+ sync_replicas=False,
+ num_aggregate=None,
+ num_replicas=None,
+ data_format="NHWC",
+ name="child",
+ mode="subgraph",
+ *args,
+ **kwargs
+ ):
+
+ super(GeneralChild, self).__init__(
+ images,
+ labels,
+ cutout_size=cutout_size,
+ batch_size=batch_size,
+ clip_mode=clip_mode,
+ grad_bound=grad_bound,
+ l2_reg=l2_reg,
+ lr_init=lr_init,
+ lr_dec_start=lr_dec_start,
+ lr_dec_every=lr_dec_every,
+ lr_dec_rate=lr_dec_rate,
+ keep_prob=keep_prob,
+ optim_algo=optim_algo,
+ sync_replicas=sync_replicas,
+ num_aggregate=num_aggregate,
+ num_replicas=num_replicas,
+ data_format=data_format,
+ name=name)
+
+ self.lr_cosine = lr_cosine
+ self.lr_max = lr_max
+ self.lr_min = lr_min
+ self.lr_T_0 = lr_T_0
+ self.lr_T_mul = lr_T_mul
+ self.out_filters = out_filters * out_filters_scale
+ self.num_layers = num_layers
+ self.mode = mode
+
+ self.num_branches = num_branches
+ self.fixed_arc = fixed_arc
+ self.out_filters_scale = out_filters_scale
+
+ pool_distance = self.num_layers // 3
+ self.pool_layers = [pool_distance - 1, 2 * pool_distance - 1]
+
+
+
+ def _factorized_reduction(self, x, out_filters, stride, is_training):
+ """Reduces the shape of x without information loss due to striding."""
+ assert out_filters % 2 == 0, (
+ "Need even number of filters when using this factorized reduction.")
+ if stride == 1:
+ with tf.variable_scope("path_conv"):
+ inp_c = get_C(x, self.data_format)
+ w = create_weight("w", [1, 1, inp_c, out_filters])
+ x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
+ data_format=self.data_format)
+ x = batch_norm(x, is_training, data_format=self.data_format)
+ return x
+
+ stride_spec = get_strides(stride, self.data_format)
+ # Skip path 1
+ path1 = tf.nn.avg_pool(
+ x, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
+ with tf.variable_scope("path1_conv"):
+ inp_c = get_C(path1, self.data_format)
+ w = create_weight("w", [1, 1, inp_c, out_filters // 2])
+ path1 = tf.nn.conv2d(path1, w, [1, 1, 1, 1], "SAME",
+ data_format=self.data_format)
+
+ # Skip path 2
+ # First pad with 0's on the right and bottom, then shift the filter to
+ # include those 0's that were added.
+ if self.data_format == "NHWC":
+ pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
+ path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
+ concat_axis = 3
+ else:
+ pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
+ path2 = tf.pad(x, pad_arr)[:, :, 1:, 1:]
+ concat_axis = 1
+
+ path2 = tf.nn.avg_pool(
+ path2, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
+ with tf.variable_scope("path2_conv"):
+ inp_c = get_C(path2, self.data_format)
+ w = create_weight("w", [1, 1, inp_c, out_filters // 2])
+ path2 = tf.nn.conv2d(path2, w, [1, 1, 1, 1], "SAME",
+ data_format=self.data_format)
+
+ # Concat and apply BN
+ final_path = tf.concat(values=[path1, path2], axis=concat_axis)
+ final_path = batch_norm(final_path, is_training,
+ data_format=self.data_format)
+
+ return final_path
+
+ def _model(self, images, is_training, reuse=False):
+ '''Build model'''
+ with tf.variable_scope(self.name, reuse=reuse):
+ layers = []
+
+ out_filters = self.out_filters
+ with tf.variable_scope("stem_conv"):
+ w = create_weight("w", [3, 3, 3, out_filters])
+ x = tf.nn.conv2d(
+ images, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
+ x = batch_norm(x, is_training, data_format=self.data_format)
+ layers.append(x)
+
+ def add_fixed_pooling_layer(layer_id, layers, out_filters, is_training):
+ '''Add a fixed pooling layer every four layers'''
+ out_filters *= 2
+ with tf.variable_scope("pool_at_{0}".format(layer_id)):
+ pooled_layers = []
+ for i, layer in enumerate(layers):
+ with tf.variable_scope("from_{0}".format(i)):
+ x = self._factorized_reduction(
+ layer, out_filters, 2, is_training)
+ pooled_layers.append(x)
+ return pooled_layers, out_filters
+
+ def post_process_out(out, optional_inputs):
+ '''Form skip connection and perform batch norm'''
+ with tf.variable_scope("skip"):
+ inputs = layers[-1]
+ if self.data_format == "NHWC":
+ inp_h = inputs.get_shape()[1].value
+ inp_w = inputs.get_shape()[2].value
+ inp_c = inputs.get_shape()[3].value
+ out.set_shape([None, inp_h, inp_w, out_filters])
+ elif self.data_format == "NCHW":
+ inp_c = inputs.get_shape()[1].value
+ inp_h = inputs.get_shape()[2].value
+ inp_w = inputs.get_shape()[3].value
+ out.set_shape([None, out_filters, inp_h, inp_w])
+ optional_inputs.append(out)
+ pout = tf.add_n(optional_inputs)
+ out = batch_norm(pout, is_training,
+ data_format=self.data_format)
+ layers.append(out)
+ return out
+
+ global layer_id
+ layer_id = -1
+
+ def get_layer_id():
+ global layer_id
+ layer_id += 1
+ return 'layer_' + str(layer_id)
+
+ def conv3(inputs):
+ # res_layers is pre_layers that are chosen to form skip connection
+ # layers[-1] is always the latest input
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_0'):
+ out = conv_op(
+ inputs[0][0], 3, is_training, out_filters, out_filters, self.data_format, start_idx=None)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ def conv3_sep(inputs):
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_1'):
+ out = conv_op(
+ inputs[0][0], 3, is_training, out_filters, out_filters, self.data_format, start_idx=None, separable=True)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ def conv5(inputs):
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_2'):
+ out = conv_op(
+ inputs[0][0], 5, is_training, out_filters, out_filters, self.data_format, start_idx=None)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ def conv5_sep(inputs):
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_3'):
+ out = conv_op(
+ inputs[0][0], 5, is_training, out_filters, out_filters, self.data_format, start_idx=None, separable=True)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ def avg_pool(inputs):
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_4'):
+ out = pool_op(
+ inputs[0][0], is_training, out_filters, out_filters, "avg", self.data_format, start_idx=None)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ def max_pool(inputs):
+ with tf.variable_scope(get_layer_id()):
+ with tf.variable_scope('branch_5'):
+ out = pool_op(
+ inputs[0][0], is_training, out_filters, out_filters, "max", self.data_format, start_idx=None)
+ out = post_process_out(out, inputs[1])
+ return out
+
+ """@nni.mutable_layers(
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs:[x],
+ layer_output: layer_0_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs:[layer_0_out],
+ optional_inputs: [layer_0_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_1_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs:[layer_1_out],
+ optional_inputs: [layer_0_out, layer_1_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_2_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs:[layer_2_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_3_out
+ }
+ )"""
+ layers, out_filters = add_fixed_pooling_layer(
+ 3, layers, out_filters, is_training)
+ layer_0_out, layer_1_out, layer_2_out, layer_3_out = layers[-4:]
+ """@nni.mutable_layers(
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_3_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_4_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_4_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_5_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_5_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_6_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_6_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_7_out
+ }
+ )"""
+ layers, out_filters = add_fixed_pooling_layer(
+ 7, layers, out_filters, is_training)
+ layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out = layers[
+ -8:]
+ """@nni.mutable_layers(
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_7_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_8_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_8_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_9_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs: [layer_9_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out, layer_9_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_10_out
+ },
+ {
+ layer_choice: [conv3(), conv3_sep(), conv5(), conv5_sep(), avg_pool(), max_pool()],
+ fixed_inputs:[layer_10_out],
+ optional_inputs: [layer_0_out, layer_1_out, layer_2_out, layer_3_out, layer_4_out, layer_5_out, layer_6_out, layer_7_out, layer_8_out, layer_9_out, layer_10_out],
+ optional_input_size: [0, 1],
+ layer_output: layer_11_out
+ }
+ )"""
+
+ x = global_avg_pool(layer_11_out, data_format=self.data_format)
+ if is_training:
+ x = tf.nn.dropout(x, self.keep_prob)
+ with tf.variable_scope("fc"):
+ if self.data_format == "NHWC":
+ inp_c = x.get_shape()[3].value
+ elif self.data_format == "NCHW":
+ inp_c = x.get_shape()[1].value
+ else:
+ raise ValueError(
+ "Unknown data_format {0}".format(self.data_format))
+ w = create_weight("w", [inp_c, 10])
+ x = tf.matmul(x, w)
+ return x
+
+
+ # override
+ def _build_train(self):
+ print("-" * 80)
+ print("Build train graph")
+ logits = self._model(self.x_train, is_training=True)
+ log_probs = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=self.y_train)
+ self.loss = tf.reduce_mean(log_probs)
+
+ self.train_preds = tf.argmax(logits, axis=1)
+ self.train_preds = tf.to_int32(self.train_preds)
+ self.train_acc = tf.equal(self.train_preds, self.y_train)
+ self.train_acc = tf.to_int32(self.train_acc)
+ self.train_acc = tf.reduce_sum(self.train_acc)
+
+ tf_variables = [var
+ for var in tf.trainable_variables() if var.name.startswith(self.name)]
+ self.num_vars = count_model_params(tf_variables)
+ print("Model has {} params".format(self.num_vars))
+
+ self.global_step = tf.Variable(
+ 0, dtype=tf.int32, trainable=False, name="global_step")
+
+ self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
+ self.loss,
+ tf_variables,
+ self.global_step,
+ clip_mode=self.clip_mode,
+ grad_bound=self.grad_bound,
+ l2_reg=self.l2_reg,
+ lr_init=self.lr_init,
+ lr_dec_start=self.lr_dec_start,
+ lr_dec_every=self.lr_dec_every,
+ lr_dec_rate=self.lr_dec_rate,
+ lr_cosine=self.lr_cosine,
+ lr_max=self.lr_max,
+ lr_min=self.lr_min,
+ lr_T_0=self.lr_T_0,
+ lr_T_mul=self.lr_T_mul,
+ num_train_batches=self.num_train_batches,
+ optim_algo=self.optim_algo,
+ sync_replicas=False,
+ num_aggregate=self.num_aggregate,
+ num_replicas=self.num_replicas)
+
+ # override
+ def _build_valid(self):
+ if self.x_valid is not None:
+ print("-" * 80)
+ print("Build valid graph")
+ logits = self._model(self.x_valid, False, reuse=True)
+ self.valid_preds = tf.argmax(logits, axis=1)
+ self.valid_preds = tf.to_int32(self.valid_preds)
+ self.valid_acc = tf.equal(self.valid_preds, self.y_valid)
+ self.valid_acc = tf.to_int32(self.valid_acc)
+ self.valid_acc = tf.reduce_sum(self.valid_acc)
+
+ # override
+ def _build_test(self):
+ print("-" * 80)
+ print("Build test graph")
+ logits = self._model(self.x_test, False, reuse=True)
+ self.test_preds = tf.argmax(logits, axis=1)
+ self.test_preds = tf.to_int32(self.test_preds)
+ self.test_acc = tf.equal(self.test_preds, self.y_test)
+ self.test_acc = tf.to_int32(self.test_acc)
+ self.test_acc = tf.reduce_sum(self.test_acc)
+
+
+ def build_model(self):
+
+ self._build_train()
+ self._build_valid()
+ self._build_test()
diff --git a/examples/trials/nas_cifar10/src/cifar10/models.py b/examples/trials/nas_cifar10/src/cifar10/models.py
new file mode 100644
index 0000000000..089fe846a6
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/cifar10/models.py
@@ -0,0 +1,196 @@
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+
+class Model(object):
+ def __init__(self,
+ images,
+ labels,
+ cutout_size=None,
+ batch_size=32,
+ eval_batch_size=100,
+ clip_mode=None,
+ grad_bound=None,
+ l2_reg=1e-4,
+ lr_init=0.1,
+ lr_dec_start=0,
+ lr_dec_every=100,
+ lr_dec_rate=0.1,
+ keep_prob=1.0,
+ optim_algo=None,
+ sync_replicas=False,
+ num_aggregate=None,
+ num_replicas=None,
+ data_format="NHWC",
+ name="generic_model",
+ seed=None,
+ ):
+ """
+ Args:
+ lr_dec_every: number of epochs between learning-rate decays (converted to steps below)
+ """
+ print("-" * 80)
+ print("Build model {}".format(name))
+
+ self.cutout_size = cutout_size
+ self.batch_size = batch_size
+ self.eval_batch_size = eval_batch_size
+ self.clip_mode = clip_mode
+ self.grad_bound = grad_bound
+ self.l2_reg = l2_reg
+ self.lr_init = lr_init
+ self.lr_dec_start = lr_dec_start
+ self.lr_dec_rate = lr_dec_rate
+ self.keep_prob = keep_prob
+ self.optim_algo = optim_algo
+ self.sync_replicas = sync_replicas
+ self.num_aggregate = num_aggregate
+ self.num_replicas = num_replicas
+ self.data_format = data_format
+ self.name = name
+ self.seed = seed
+
+ self.global_step = None
+ self.valid_acc = None
+ self.test_acc = None
+ print("Build data ops")
+ with tf.device("/cpu:0"):
+ # training data
+ self.num_train_examples = np.shape(images["train"])[0]
+
+ self.num_train_batches = (
+ self.num_train_examples + self.batch_size - 1) // self.batch_size
+ x_train, y_train = tf.train.shuffle_batch(
+ [images["train"], labels["train"]],
+ batch_size=self.batch_size,
+ capacity=50000,
+ enqueue_many=True,
+ min_after_dequeue=0,
+ num_threads=16,
+ seed=self.seed,
+ allow_smaller_final_batch=True,
+ )
+ self.lr_dec_every = lr_dec_every * self.num_train_batches
+
+ def _pre_process(x):
+ x = tf.pad(x, [[4, 4], [4, 4], [0, 0]])
+ x = tf.random_crop(x, [32, 32, 3], seed=self.seed)
+ x = tf.image.random_flip_left_right(x, seed=self.seed)
+ if self.cutout_size is not None:
+ mask = tf.ones(
+ [self.cutout_size, self.cutout_size], dtype=tf.int32)
+ start = tf.random_uniform(
+ [2], minval=0, maxval=32, dtype=tf.int32)
+ mask = tf.pad(mask, [[self.cutout_size + start[0], 32 - start[0]],
+ [self.cutout_size + start[1], 32 - start[1]]])
+ mask = mask[self.cutout_size: self.cutout_size + 32,
+ self.cutout_size: self.cutout_size + 32]
+ mask = tf.reshape(mask, [32, 32, 1])
+ mask = tf.tile(mask, [1, 1, 3])
+ x = tf.where(tf.equal(mask, 0), x=x, y=tf.zeros_like(x))
+ if self.data_format == "NCHW":
+ x = tf.transpose(x, [2, 0, 1])
+
+ return x
+ self.x_train = tf.map_fn(_pre_process, x_train, back_prop=False)
+ self.y_train = y_train
+
+ # valid data
+ self.x_valid, self.y_valid = None, None
+ if images["valid"] is not None:
+ images["valid_original"] = np.copy(images["valid"])
+ labels["valid_original"] = np.copy(labels["valid"])
+ if self.data_format == "NCHW":
+ images["valid"] = tf.transpose(
+ images["valid"], [0, 3, 1, 2])
+ self.num_valid_examples = np.shape(images["valid"])[0]
+ self.num_valid_batches = (
+ (self.num_valid_examples + self.eval_batch_size - 1)
+ // self.eval_batch_size)
+ self.x_valid, self.y_valid = tf.train.batch(
+ [images["valid"], labels["valid"]],
+ batch_size=self.eval_batch_size,
+ capacity=5000,
+ enqueue_many=True,
+ num_threads=1,
+ allow_smaller_final_batch=True,
+ )
+
+ # test data
+ if self.data_format == "NCHW":
+ images["test"] = tf.transpose(images["test"], [0, 3, 1, 2])
+ self.num_test_examples = np.shape(images["test"])[0]
+ self.num_test_batches = (
+ (self.num_test_examples + self.eval_batch_size - 1)
+ // self.eval_batch_size)
+ self.x_test, self.y_test = tf.train.batch(
+ [images["test"], labels["test"]],
+ batch_size=self.eval_batch_size,
+ capacity=10000,
+ enqueue_many=True,
+ num_threads=1,
+ allow_smaller_final_batch=True,
+ )
+
+ # cache images and labels
+ self.images = images
+ self.labels = labels
+
+ def eval_once(self, sess, eval_set, child_model, verbose=False):
+ """Expects self.acc and self.global_step to be defined.
+
+ Args:
+ sess: tf.Session() or one of its wrappers.
+ eval_set: "valid" or "test".
+ verbose: if True, print the running correct/total counts after each batch.
+ """
+
+ assert self.global_step is not None
+ global_step = sess.run(self.global_step)
+ print("Eval at {}".format(global_step))
+
+ if eval_set == "valid":
+ assert self.x_valid is not None
+ assert self.valid_acc is not None
+ num_examples = self.num_valid_examples
+ num_batches = self.num_valid_batches
+ acc_op = self.valid_acc
+ elif eval_set == "test":
+ assert self.test_acc is not None
+ num_examples = self.num_test_examples
+ num_batches = self.num_test_batches
+ acc_op = self.test_acc
+ else:
+ raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))
+
+ total_acc = 0
+ total_exp = 0
+
+ for batch_id in range(num_batches):
+ acc = sess.run(acc_op)
+
+ total_acc += acc
+ total_exp += self.eval_batch_size
+ if verbose:
+ sys.stdout.write(
+ "\r{:<5d}/{:>5d}".format(total_acc, total_exp))
+ if verbose:
+ print("")
+ print("{}_accuracy: {:<6.4f}".format(
+ eval_set, float(total_acc) / total_exp))
+ return float(total_acc) / total_exp
+
+ def _model(self, images, is_training, reuse=None):
+ raise NotImplementedError("Abstract method")
+
+ def _build_train(self):
+ raise NotImplementedError("Abstract method")
+
+ def _build_valid(self):
+ raise NotImplementedError("Abstract method")
+
+ def _build_test(self):
+ raise NotImplementedError("Abstract method")
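Review note on `models.py`: `Model.__init__` derives every batch count with the ceiling-division idiom `(n + b - 1) // b`, and `eval_once` then adds a full `eval_batch_size` to its denominator for each batch. The sketch below only illustrates that arithmetic in plain Python. The helper names `num_batches` and `eval_denominator` are hypothetical and do not appear in the patch. Whether a short final batch actually reaches `eval_once` depends on the input queue, which this sketch does not model.

```python
def num_batches(num_examples, batch_size):
    """Ceiling division, written the same way as in Model.__init__."""
    return (num_examples + batch_size - 1) // batch_size


def eval_denominator(num_examples, batch_size):
    """Total that eval_once accumulates: one full batch size per batch."""
    return num_batches(num_examples, batch_size) * batch_size


# CIFAR-10 test set (10000 images) with the default eval_batch_size of 100:
even_batches = num_batches(10000, 100)
even_total = eval_denominator(10000, 100)

# A hypothetical example count that does not divide evenly:
odd_batches = num_batches(10050, 100)
odd_total = eval_denominator(10050, 100)
```

With the default settings the denominator matches the example count exactly. For a count that is not a multiple of the batch size, the denominator would exceed the number of examples, which is worth keeping in mind if `eval_batch_size` is changed.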
diff --git a/examples/trials/nas_cifar10/src/cifar10/nni_child_cifar10.py b/examples/trials/nas_cifar10/src/cifar10/nni_child_cifar10.py
new file mode 100644
index 0000000000..5481ba7b07
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/cifar10/nni_child_cifar10.py
@@ -0,0 +1,162 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import shutil
+import logging
+import tensorflow as tf
+from src.cifar10.data_utils import read_data
+from src.cifar10.general_child import GeneralChild
+import src.cifar10_flags
+from src.cifar10_flags import FLAGS
+
+
+def build_logger(log_name):
+ logger = logging.getLogger(log_name)
+ logger.setLevel(logging.DEBUG)
+ fh = logging.FileHandler(log_name+'.log')
+ fh.setLevel(logging.DEBUG)
+ logger.addHandler(fh)
+ return logger
+
+
+logger = build_logger("nni_child_cifar10")
+
+
+def build_trial(images, labels, ChildClass):
+ '''Build child class'''
+ child_model = ChildClass(
+ images,
+ labels,
+ use_aux_heads=FLAGS.child_use_aux_heads,
+ cutout_size=FLAGS.child_cutout_size,
+ num_layers=FLAGS.child_num_layers,
+ num_cells=FLAGS.child_num_cells,
+ num_branches=FLAGS.child_num_branches,
+ fixed_arc=FLAGS.child_fixed_arc,
+ out_filters_scale=FLAGS.child_out_filters_scale,
+ out_filters=FLAGS.child_out_filters,
+ keep_prob=FLAGS.child_keep_prob,
+ drop_path_keep_prob=FLAGS.child_drop_path_keep_prob,
+ num_epochs=FLAGS.num_epochs,
+ l2_reg=FLAGS.child_l2_reg,
+ data_format=FLAGS.data_format,
+ batch_size=FLAGS.batch_size,
+ clip_mode="norm",
+ grad_bound=FLAGS.child_grad_bound,
+ lr_init=FLAGS.child_lr,
+ lr_dec_every=FLAGS.child_lr_dec_every,
+ lr_dec_rate=FLAGS.child_lr_dec_rate,
+ lr_cosine=FLAGS.child_lr_cosine,
+ lr_max=FLAGS.child_lr_max,
+ lr_min=FLAGS.child_lr_min,
+ lr_T_0=FLAGS.child_lr_T_0,
+ lr_T_mul=FLAGS.child_lr_T_mul,
+ optim_algo="momentum",
+ sync_replicas=FLAGS.child_sync_replicas,
+ num_aggregate=FLAGS.child_num_aggregate,
+ num_replicas=FLAGS.child_num_replicas
+ )
+
+ return child_model
+
+
+def get_child_ops(child_model):
+ '''Assemble child op to a dict'''
+ child_ops = {
+ "global_step": child_model.global_step,
+ "loss": child_model.loss,
+ "train_op": child_model.train_op,
+ "lr": child_model.lr,
+ "grad_norm": child_model.grad_norm,
+ "train_acc": child_model.train_acc,
+ "optimizer": child_model.optimizer,
+ "num_train_batches": child_model.num_train_batches,
+ "eval_every": child_model.num_train_batches * FLAGS.eval_every_epochs,
+ "eval_func": child_model.eval_once,
+ }
+ return child_ops
+
+
+class NASTrial():
+
+ def __init__(self):
+ images, labels = read_data(FLAGS.data_path, num_valids=0)
+
+ self.output_dir = os.path.join(os.getenv('NNI_OUTPUT_DIR'), '../..')
+ self.file_path = os.path.join(
+ self.output_dir, 'trainable_variable.txt')
+
+ self.graph = tf.Graph()
+ with self.graph.as_default():
+ self.child_model = build_trial(images, labels, GeneralChild)
+
+ self.total_data = {}
+
+ self.child_model.build_model()
+ self.child_ops = get_child_ops(self.child_model)
+ config = tf.ConfigProto(
+ intra_op_parallelism_threads=0,
+ inter_op_parallelism_threads=0,
+ allow_soft_placement=True)
+
+ self.sess = tf.train.SingularMonitoredSession(config=config)
+
+ logger.debug('NASTrial initialized.')
+
+ def run_one_step(self):
+ '''Run this model on a batch of data'''
+ run_ops = [
+ self.child_ops["loss"],
+ self.child_ops["lr"],
+ self.child_ops["grad_norm"],
+ self.child_ops["train_acc"],
+ self.child_ops["train_op"],
+ ]
+ loss, lr, gn, tr_acc, _ = self.sess.run(run_ops)
+ global_step = self.sess.run(self.child_ops["global_step"])
+ log_string = ""
+ log_string += "ch_step={:<6d}".format(global_step)
+ log_string += " loss={:<8.6f}".format(loss)
+ log_string += " lr={:<8.4f}".format(lr)
+ log_string += " |g|={:<8.4f}".format(gn)
+ log_string += " tr_acc={:<3d}/{:>3d}".format(tr_acc, FLAGS.batch_size)
+ if int(global_step) % FLAGS.log_every == 0:
+ logger.debug(log_string)
+ return loss, global_step
+
+ def run(self):
+ '''Run this model for the number of epochs set in FLAGS.num_epochs'''
+ max_acc = 0
+ while True:
+ _, global_step = self.run_one_step()
+ if global_step % self.child_ops['num_train_batches'] == 0:
+ acc = self.child_ops["eval_func"](
+ self.sess, "test", self.child_model)
+ max_acc = max(max_acc, acc)
+ '''@nni.report_intermediate_result(acc)'''
+ if global_step / self.child_ops['num_train_batches'] >= FLAGS.num_epochs:
+ '''@nni.report_final_result(max_acc)'''
+ break
+
+
+def main(_):
+ logger.debug("-" * 80)
+
+ if not os.path.isdir(FLAGS.output_dir):
+ logger.debug(
+ "Path {} does not exist. Creating.".format(FLAGS.output_dir))
+ os.makedirs(FLAGS.output_dir)
+ elif FLAGS.reset_output_dir:
+ logger.debug(
+ "Path {} exists. Remove and remake.".format(FLAGS.output_dir))
+ shutil.rmtree(FLAGS.output_dir)
+ os.makedirs(FLAGS.output_dir)
+ logger.debug("-" * 80)
+ trial = NASTrial()
+
+ trial.run()
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/examples/trials/nas_cifar10/src/cifar10_flags.py b/examples/trials/nas_cifar10/src/cifar10_flags.py
new file mode 100644
index 0000000000..2374f76b90
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/cifar10_flags.py
@@ -0,0 +1,45 @@
+import tensorflow as tf
+from src.utils import DEFINE_boolean
+from src.utils import DEFINE_float
+from src.utils import DEFINE_integer
+from src.utils import DEFINE_string
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+DEFINE_boolean("reset_output_dir", False, "Delete output_dir if exists.")
+DEFINE_string("data_path", "", "")
+DEFINE_string("output_dir", "", "")
+DEFINE_string("data_format", "NHWC", "'NHWC' or 'NCHW'")
+DEFINE_string("search_for", None, "Must be [macro|micro]")
+DEFINE_integer("train_data_size", 45000, "")
+DEFINE_integer("batch_size", 32, "")
+
+DEFINE_integer("num_epochs", 300, "")
+DEFINE_integer("child_lr_dec_every", 100, "")
+DEFINE_integer("child_num_layers", 5, "")
+DEFINE_integer("child_num_cells", 5, "")
+DEFINE_integer("child_filter_size", 5, "")
+DEFINE_integer("child_out_filters", 48, "")
+DEFINE_integer("child_out_filters_scale", 1, "")
+DEFINE_integer("child_num_branches", 4, "")
+DEFINE_integer("child_num_aggregate", None, "")
+DEFINE_integer("child_num_replicas", 1, "")
+DEFINE_integer("child_block_size", 3, "")
+DEFINE_integer("child_lr_T_0", None, "for lr schedule")
+DEFINE_integer("child_lr_T_mul", None, "for lr schedule")
+DEFINE_integer("child_cutout_size", None, "CutOut size")
+DEFINE_float("child_grad_bound", 5.0, "Gradient clipping")
+DEFINE_float("child_lr", 0.1, "")
+DEFINE_float("child_lr_dec_rate", 0.1, "")
+DEFINE_float("child_keep_prob", 0.5, "")
+DEFINE_float("child_drop_path_keep_prob", 1.0, "minimum drop_path_keep_prob")
+DEFINE_float("child_l2_reg", 1e-4, "")
+DEFINE_float("child_lr_max", None, "for lr schedule")
+DEFINE_float("child_lr_min", None, "for lr schedule")
+DEFINE_string("child_skip_pattern", None, "Must be ['dense', None]")
+DEFINE_string("child_fixed_arc", None, "")
+DEFINE_boolean("child_use_aux_heads", False, "Should we use an aux head")
+DEFINE_boolean("child_sync_replicas", False, "To sync or not to sync.")
+DEFINE_boolean("child_lr_cosine", False, "Use cosine lr schedule")
+DEFINE_integer("log_every", 50, "Log once every this many training steps")
+DEFINE_integer("eval_every_epochs", 1, "Evaluate once every this many epochs")
diff --git a/examples/trials/nas_cifar10/src/common_ops.py b/examples/trials/nas_cifar10/src/common_ops.py
new file mode 100644
index 0000000000..e0933f6e53
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/common_ops.py
@@ -0,0 +1,255 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.training import moving_averages
+
+
+def lstm(x, prev_c, prev_h, w):
+ ifog = tf.matmul(tf.concat([x, prev_h], axis=1), w)
+ i, f, o, g = tf.split(ifog, 4, axis=1)
+ i = tf.sigmoid(i)
+ f = tf.sigmoid(f)
+ o = tf.sigmoid(o)
+ g = tf.tanh(g)
+ next_c = i * g + f * prev_c
+ next_h = o * tf.tanh(next_c)
+ return next_c, next_h
+
+
+def stack_lstm(x, prev_c, prev_h, w):
+ next_c, next_h = [], []
+ for layer_id, (_c, _h, _w) in enumerate(zip(prev_c, prev_h, w)):
+ inputs = x if layer_id == 0 else next_h[-1]
+ curr_c, curr_h = lstm(inputs, _c, _h, _w)
+ next_c.append(curr_c)
+ next_h.append(curr_h)
+ return next_c, next_h
+
+
+def create_weight(name, shape, initializer=None, trainable=True, seed=None):
+ if initializer is None:
+ initializer = tf.contrib.keras.initializers.he_normal(seed=seed)
+ return tf.get_variable(name, shape, initializer=initializer, trainable=trainable)
+
+
+def create_bias(name, shape, initializer=None):
+ if initializer is None:
+ initializer = tf.constant_initializer(0.0, dtype=tf.float32)
+ return tf.get_variable(name, shape, initializer=initializer)
+
+
+def conv_op(inputs, filter_size, is_training, count, out_filters,
+ data_format, ch_mul=1, start_idx=None, separable=False):
+ """
+ Args:
+ start_idx: where to start taking the output channels. if None, assuming
+ fixed_arc mode
+ count: how many output_channels to take.
+ """
+
+ if data_format == "NHWC":
+ inp_c = inputs.get_shape()[3].value
+ elif data_format == "NCHW":
+ inp_c = inputs.get_shape()[1].value
+
+ with tf.variable_scope("inp_conv_1"):
+ w = create_weight("w", [1, 1, inp_c, out_filters])
+ x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1],
+ "SAME", data_format=data_format)
+ x = batch_norm(x, is_training, data_format=data_format)
+ x = tf.nn.relu(x)
+
+ with tf.variable_scope("out_conv_{}".format(filter_size)):
+ if start_idx is None:
+ if separable:
+ w_depth = create_weight(
+ "w_depth", [filter_size, filter_size, out_filters, ch_mul])
+ w_point = create_weight(
+ "w_point", [1, 1, out_filters * ch_mul, count])
+ x = tf.nn.separable_conv2d(x, w_depth, w_point, strides=[1, 1, 1, 1],
+ padding="SAME", data_format=data_format)
+ x = batch_norm(
+ x, is_training, data_format=data_format)
+ else:
+ w = create_weight(
+ "w", [filter_size, filter_size, inp_c, count])
+ x = tf.nn.conv2d(
+ x, w, [1, 1, 1, 1], "SAME", data_format=data_format)
+ x = batch_norm(
+ x, is_training, data_format=data_format)
+ else:
+ if separable:
+ w_depth = create_weight(
+ "w_depth", [filter_size, filter_size, out_filters, ch_mul])
+ #test_depth = w_depth
+ w_point = create_weight(
+ "w_point", [out_filters, out_filters * ch_mul])
+ w_point = w_point[start_idx:start_idx+count, :]
+ w_point = tf.transpose(w_point, [1, 0])
+ w_point = tf.reshape(
+ w_point, [1, 1, out_filters * ch_mul, count])
+
+ x = tf.nn.separable_conv2d(x, w_depth, w_point, strides=[1, 1, 1, 1],
+ padding="SAME", data_format=data_format)
+ mask = tf.range(0, out_filters, dtype=tf.int32)
+ mask = tf.logical_and(
+ start_idx <= mask, mask < start_idx + count)
+ x = batch_norm_with_mask(
+ x, is_training, mask, out_filters, data_format=data_format)
+ else:
+ w = create_weight(
+ "w", [filter_size, filter_size, out_filters, out_filters])
+ w = tf.transpose(w, [3, 0, 1, 2])
+ w = w[start_idx:start_idx+count, :, :, :]
+ w = tf.transpose(w, [1, 2, 3, 0])
+ x = tf.nn.conv2d(
+ x, w, [1, 1, 1, 1], "SAME", data_format=data_format)
+ mask = tf.range(0, out_filters, dtype=tf.int32)
+ mask = tf.logical_and(
+ start_idx <= mask, mask < start_idx + count)
+ x = batch_norm_with_mask(
+ x, is_training, mask, out_filters, data_format=data_format)
+ x = tf.nn.relu(x)
+ return x
+
+def pool_op(inputs, is_training, count, out_filters, avg_or_max, data_format, start_idx=None):
+ """
+ Args:
+ start_idx: where to start taking the output channels. if None, assuming
+ fixed_arc mode
+ count: how many output_channels to take.
+ """
+
+ if data_format == "NHWC":
+ inp_c = inputs.get_shape()[3].value
+ elif data_format == "NCHW":
+ inp_c = inputs.get_shape()[1].value
+
+ with tf.variable_scope("conv_1"):
+ w = create_weight("w", [1, 1, inp_c, out_filters])
+ x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1],
+ "SAME", data_format=data_format)
+ x = batch_norm(x, is_training, data_format=data_format)
+ x = tf.nn.relu(x)
+
+ with tf.variable_scope("pool"):
+ if data_format == "NHWC":
+ actual_data_format = "channels_last"
+ elif data_format == "NCHW":
+ actual_data_format = "channels_first"
+
+ if avg_or_max == "avg":
+ x = tf.layers.average_pooling2d(
+ x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
+ elif avg_or_max == "max":
+ x = tf.layers.max_pooling2d(
+ x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
+ else:
+ raise ValueError("Unknown pool {}".format(avg_or_max))
+
+ if start_idx is not None:
+ if data_format == "NHWC":
+ x = x[:, :, :, start_idx: start_idx+count]
+ elif data_format == "NCHW":
+ x = x[:, start_idx: start_idx+count, :, :]
+
+ return x
+
+
+def global_avg_pool(x, data_format="NHWC"):
+ if data_format == "NHWC":
+ x = tf.reduce_mean(x, [1, 2])
+ elif data_format == "NCHW":
+ x = tf.reduce_mean(x, [2, 3])
+ else:
+ raise NotImplementedError("Unknown data_format {}".format(data_format))
+ return x
+
+
+def batch_norm(x, is_training, name="bn", decay=0.9, epsilon=1e-5,
+ data_format="NHWC"):
+ if data_format == "NHWC":
+ shape = [x.get_shape()[3]]
+ elif data_format == "NCHW":
+ shape = [x.get_shape()[1]]
+ else:
+ raise NotImplementedError("Unknown data_format {}".format(data_format))
+
+ with tf.variable_scope(name, reuse=None if is_training else True):
+ offset = tf.get_variable(
+ "offset", shape,
+ initializer=tf.constant_initializer(0.0, dtype=tf.float32))
+ scale = tf.get_variable(
+ "scale", shape,
+ initializer=tf.constant_initializer(1.0, dtype=tf.float32))
+ moving_mean = tf.get_variable(
+ "moving_mean", shape, trainable=False,
+ initializer=tf.constant_initializer(0.0, dtype=tf.float32))
+ moving_variance = tf.get_variable(
+ "moving_variance", shape, trainable=False,
+ initializer=tf.constant_initializer(1.0, dtype=tf.float32))
+
+ if is_training:
+ x, mean, variance = tf.nn.fused_batch_norm(
+ x, scale, offset, epsilon=epsilon, data_format=data_format,
+ is_training=True)
+ update_mean = moving_averages.assign_moving_average(
+ moving_mean, mean, decay)
+ update_variance = moving_averages.assign_moving_average(
+ moving_variance, variance, decay)
+ with tf.control_dependencies([update_mean, update_variance]):
+ x = tf.identity(x)
+ else:
+ x, _, _ = tf.nn.fused_batch_norm(x, scale, offset, mean=moving_mean,
+ variance=moving_variance,
+ epsilon=epsilon, data_format=data_format,
+ is_training=False)
+ return x
+
+
+def batch_norm_with_mask(x, is_training, mask, num_channels, name="bn",
+ decay=0.9, epsilon=1e-3, data_format="NHWC"):
+
+ shape = [num_channels]
+ indices = tf.where(mask)
+ indices = tf.to_int32(indices)
+ indices = tf.reshape(indices, [-1])
+
+ with tf.variable_scope(name, reuse=None if is_training else True):
+ offset = tf.get_variable(
+ "offset", shape,
+ initializer=tf.constant_initializer(0.0, dtype=tf.float32))
+ scale = tf.get_variable(
+ "scale", shape,
+ initializer=tf.constant_initializer(1.0, dtype=tf.float32))
+ offset = tf.boolean_mask(offset, mask)
+ scale = tf.boolean_mask(scale, mask)
+
+ moving_mean = tf.get_variable(
+ "moving_mean", shape, trainable=False,
+ initializer=tf.constant_initializer(0.0, dtype=tf.float32))
+ moving_variance = tf.get_variable(
+ "moving_variance", shape, trainable=False,
+ initializer=tf.constant_initializer(1.0, dtype=tf.float32))
+
+ if is_training:
+ x, mean, variance = tf.nn.fused_batch_norm(
+ x, scale, offset, epsilon=epsilon, data_format=data_format,
+ is_training=True)
+ mean = (1.0 - decay) * (tf.boolean_mask(moving_mean, mask) - mean)
+ variance = (1.0 - decay) * \
+ (tf.boolean_mask(moving_variance, mask) - variance)
+ update_mean = tf.scatter_sub(
+ moving_mean, indices, mean, use_locking=True)
+ update_variance = tf.scatter_sub(
+ moving_variance, indices, variance, use_locking=True)
+ with tf.control_dependencies([update_mean, update_variance]):
+ x = tf.identity(x)
+ else:
+ masked_moving_mean = tf.boolean_mask(moving_mean, mask)
+ masked_moving_variance = tf.boolean_mask(moving_variance, mask)
+ x, _, _ = tf.nn.fused_batch_norm(x, scale, offset,
+ mean=masked_moving_mean,
+ variance=masked_moving_variance,
+ epsilon=epsilon, data_format=data_format,
+ is_training=False)
+ return x
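Review note on `batch_norm_with_mask`: in the training branch, the moving statistics are updated with `tf.scatter_sub`, which subtracts `(1 - decay) * (moving - batch)` at the masked indices. Algebraically this is the usual exponential moving average `decay * moving + (1 - decay) * batch`, applied only to the selected channels. The plain-Python sketch below checks that equivalence on toy values. The function `masked_ema_update` and the sample numbers are hypothetical stand-ins for the TensorFlow ops, not part of the patch.

```python
def masked_ema_update(moving, batch_stats, mask, decay=0.9):
    """Update only the masked channels of a moving statistic.

    moving:      per-channel moving values for all channels
    batch_stats: batch values, one per selected channel
    mask:        booleans, True for the selected channels
    """
    indices = [i for i, keep in enumerate(mask) if keep]
    assert len(indices) == len(batch_stats)
    updated = list(moving)
    for idx, stat in zip(indices, batch_stats):
        # Same form as the scatter_sub update:
        # moving -= (1 - decay) * (moving - stat)
        updated[idx] = moving[idx] - (1.0 - decay) * (moving[idx] - stat)
    return updated


moving_mean = [0.0, 1.0, 2.0, 3.0]
mask = [False, True, True, False]
batch_mean = [5.0, 6.0]
new_mean = masked_ema_update(moving_mean, batch_mean, mask, decay=0.9)
```

Channels outside the mask keep their previous moving value, which is what lets differently sliced branches share one set of `moving_mean` and `moving_variance` variables.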
diff --git a/examples/trials/nas_cifar10/src/utils.py b/examples/trials/nas_cifar10/src/utils.py
new file mode 100644
index 0000000000..65d57af7f1
--- /dev/null
+++ b/examples/trials/nas_cifar10/src/utils.py
@@ -0,0 +1,262 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import numpy as np
+import tensorflow as tf
+
+
+user_flags = []
+
+
+def DEFINE_string(name, default_value, doc_string):
+ tf.app.flags.DEFINE_string(name, default_value, doc_string)
+ global user_flags
+ user_flags.append(name)
+
+
+def DEFINE_integer(name, default_value, doc_string):
+ tf.app.flags.DEFINE_integer(name, default_value, doc_string)
+ global user_flags
+ user_flags.append(name)
+
+
+def DEFINE_float(name, default_value, doc_string):
+ tf.app.flags.DEFINE_float(name, default_value, doc_string)
+ global user_flags
+ user_flags.append(name)
+
+
+def DEFINE_boolean(name, default_value, doc_string):
+ tf.app.flags.DEFINE_boolean(name, default_value, doc_string)
+ global user_flags
+ user_flags.append(name)
+
+
+def print_user_flags(line_limit=80):
+ print("-" * 80)
+
+ global user_flags
+ FLAGS = tf.app.flags.FLAGS
+
+ for flag_name in sorted(user_flags):
+ value = "{}".format(getattr(FLAGS, flag_name))
+ log_string = flag_name
+ log_string += "." * (line_limit - len(flag_name) - len(value))
+ log_string += value
+ print(log_string)
+
+
+def get_C(x, data_format):
+ """
+ Args:
+ x: tensor of shape [N, H, W, C] or [N, C, H, W]
+ """
+ if data_format == "NHWC":
+ return x.get_shape()[3].value
+ elif data_format == "NCHW":
+ return x.get_shape()[1].value
+ else:
+ raise ValueError(
+ "Unknown data_format '{0}'".format(data_format))
+
+def get_HW(x, data_format):
+ """
+ Args:
+ x: tensor of shape [N, H, W, C] or [N, C, H, W]
+ """
+ return x.get_shape()[2].value
+
+def get_strides(stride, data_format):
+ """
+ Args:
+ stride: spatial stride applied to both height and width
+ """
+ if data_format == "NHWC":
+ return [1, stride, stride, 1]
+ elif data_format == "NCHW":
+ return [1, 1, stride, stride]
+ else:
+ raise ValueError(
+ "Unknown data_format '{0}'".format(data_format))
+
+
+class TextColors:
+ HEADER = '\033[95m'
+ OKBLUE = '\033[94m'
+ OKGREEN = '\033[92m'
+ WARNING = '\033[93m'
+ FAIL = '\033[91m'
+ ENDC = '\033[0m'
+ BOLD = '\033[1m'
+ UNDERLINE = '\033[4m'
+
+
+class Logger(object):
+ def __init__(self, output_file):
+ self.terminal = sys.stdout
+ self.log = open(output_file, "a")
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.terminal.flush()
+ self.log.write(message)
+ self.log.flush()
+
+
+def count_model_params(tf_variables):
+ """
+ Args:
+ tf_variables: list of all model variables
+ """
+
+ num_vars = 0
+ for var in tf_variables:
+ num_vars += np.prod([dim.value for dim in var.get_shape()])
+ return num_vars
+
+
+def get_train_ops(
+ loss,
+ tf_variables,
+ train_step,
+ clip_mode=None,
+ grad_bound=None,
+ l2_reg=1e-4,
+ lr_warmup_val=None,
+ lr_warmup_steps=100,
+ lr_init=0.1,
+ lr_dec_start=0,
+ lr_dec_every=10000,
+ lr_dec_rate=0.1,
+ lr_dec_min=None,
+ lr_cosine=False,
+ lr_max=None,
+ lr_min=None,
+ lr_T_0=None,
+ lr_T_mul=None,
+ num_train_batches=None,
+ optim_algo=None,
+ sync_replicas=False,
+ num_aggregate=None,
+ num_replicas=None,
+ get_grad_norms=False,
+ moving_average=None):
+ """
+ Args:
+ clip_mode: "global", "norm", or None.
+ moving_average: store the moving average of parameters
+ """
+
+ if l2_reg > 0:
+ l2_losses = []
+ for var in tf_variables:
+ l2_losses.append(tf.reduce_sum(var ** 2))
+ l2_loss = tf.add_n(l2_losses)
+ loss += l2_reg * l2_loss
+
+ grads = tf.gradients(loss, tf_variables)
+ grad_norm = tf.global_norm(grads)
+
+ grad_norms = {}
+ for v, g in zip(tf_variables, grads):
+ if v is None or g is None:
+ continue
+ if isinstance(g, tf.IndexedSlices):
+ grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g.values ** 2))
+ else:
+ grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g ** 2))
+
+ if clip_mode is not None:
+ assert grad_bound is not None, "Need grad_bound to clip gradients."
+ if clip_mode == "global":
+ grads, _ = tf.clip_by_global_norm(grads, grad_bound)
+ elif clip_mode == "norm":
+ clipped = []
+ for g in grads:
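+ if isinstance(g, tf.IndexedSlices):
+ c_g = tf.clip_by_norm(g.values, grad_bound)
+ # tf.IndexedSlices takes (values, indices, dense_shape)
+ c_g = tf.IndexedSlices(c_g, g.indices, g.dense_shape)
+ else:
+ c_g = tf.clip_by_norm(g, grad_bound)
+ # append the clipped gradient, not the original one
+ clipped.append(c_g)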
+ grads = clipped
+ else:
+ raise NotImplementedError("Unknown clip_mode {}".format(clip_mode))
+
+ if lr_cosine:
+ assert lr_max is not None, "Need lr_max to use lr_cosine"
+ assert lr_min is not None, "Need lr_min to use lr_cosine"
+ assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
+ assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
+ assert num_train_batches is not None, ("Need num_train_batches to use"
+ " lr_cosine")
+
+ curr_epoch = train_step // num_train_batches
+
+ last_reset = tf.Variable(0, dtype=tf.int32, trainable=False,
+ name="last_reset")
+ T_i = tf.Variable(lr_T_0, dtype=tf.int32, trainable=False, name="T_i")
+ T_curr = curr_epoch - last_reset
+
+ def _update():
+ update_last_reset = tf.assign(
+ last_reset, curr_epoch, use_locking=True)
+ update_T_i = tf.assign(T_i, T_i * lr_T_mul, use_locking=True)
+ with tf.control_dependencies([update_last_reset, update_T_i]):
+ rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+ lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+ return lr
+
+ def _no_update():
+ rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+ lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+ return lr
+
+ learning_rate = tf.cond(
+ tf.greater_equal(T_curr, T_i), _update, _no_update)
+ else:
+ learning_rate = tf.train.exponential_decay(
+ lr_init, tf.maximum(train_step - lr_dec_start, 0), lr_dec_every,
+ lr_dec_rate, staircase=True)
+ if lr_dec_min is not None:
+ learning_rate = tf.maximum(learning_rate, lr_dec_min)
+
+ if lr_warmup_val is not None:
+ learning_rate = tf.cond(tf.less(train_step, lr_warmup_steps),
+ lambda: lr_warmup_val, lambda: learning_rate)
+
+ if optim_algo == "momentum":
+ opt = tf.train.MomentumOptimizer(
+ learning_rate, 0.9, use_locking=True, use_nesterov=True)
+ elif optim_algo == "sgd":
+ opt = tf.train.GradientDescentOptimizer(
+ learning_rate, use_locking=True)
+ elif optim_algo == "adam":
+ opt = tf.train.AdamOptimizer(learning_rate, beta1=0.0, epsilon=1e-3,
+ use_locking=True)
+ else:
+ raise ValueError("Unknown optim_algo {}".format(optim_algo))
+
+ if sync_replicas:
+ assert num_aggregate is not None, "Need num_aggregate to sync."
+ assert num_replicas is not None, "Need num_replicas to sync."
+
+ opt = tf.train.SyncReplicasOptimizer(
+ opt,
+ replicas_to_aggregate=num_aggregate,
+ total_num_replicas=num_replicas,
+ use_locking=True)
+
+ if moving_average is not None:
+ opt = tf.contrib.opt.MovingAverageOptimizer(
+ opt, average_decay=moving_average)
+
+ train_op = opt.apply_gradients(
+ zip(grads, tf_variables), global_step=train_step)
+
+ if get_grad_norms:
+ return train_op, learning_rate, grad_norm, opt, grad_norms
+ else:
+ return train_op, learning_rate, grad_norm, opt
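Review note: the `lr_cosine` branch of `get_train_ops` follows a cosine schedule with warm restarts. The sketch below restates that logic in plain Python so the restart behaviour is easier to check by hand. It is an illustration under assumptions, not the graph-mode code itself: the class name `CosineRestartSchedule` is hypothetical, and `T_curr` is reset to 0 at a restart, which the TensorFlow version may not do in the same step.

```python
import math


class CosineRestartSchedule:
    """Hypothetical plain-Python restatement of the lr_cosine branch."""

    def __init__(self, lr_max, lr_min, t_0, t_mul):
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.t_i = t_0          # length of the current cycle, in epochs
        self.t_mul = t_mul      # factor by which each new cycle grows
        self.last_reset = 0     # epoch at which the current cycle started

    def learning_rate(self, epoch):
        """Return the learning rate; call with non-decreasing epochs."""
        t_curr = epoch - self.last_reset
        if t_curr >= self.t_i:
            # Cycle finished: restart and lengthen the next cycle.
            self.last_reset = epoch
            self.t_i *= self.t_mul
            t_curr = 0  # assumption: the restart takes effect immediately
        rate = t_curr / self.t_i * math.pi
        return self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1.0 + math.cos(rate))


schedule = CosineRestartSchedule(lr_max=0.05, lr_min=0.001, t_0=10, t_mul=2)
for epoch in (0, 5, 10, 20):
    print(epoch, schedule.learning_rate(epoch))
```

At the start of each cycle the rate equals `lr_max`, and it decays towards `lr_min` as `T_curr` approaches `T_i`.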
diff --git a/examples/tuners/random_nas_tuner/random_nas_tuner.py b/examples/tuners/random_nas_tuner/random_nas_tuner.py
index d7f6214aa6..c13bc72c6b 100644
--- a/examples/tuners/random_nas_tuner/random_nas_tuner.py
+++ b/examples/tuners/random_nas_tuner/random_nas_tuner.py
@@ -2,13 +2,14 @@
from nni.tuner import Tuner
+
def random_archi_generator(nas_ss, random_state):
'''random
'''
chosen_archi = {}
- print("zql: nas search space: ", nas_ss)
for block_name, block_value in nas_ss.items():
- assert block_value['_type'] == "mutable_layer", "Random NAS Tuner only receives NAS search space whose _type is 'mutable_layer'"
+ assert block_value['_type'] == "mutable_layer", \
+ "Random NAS Tuner only receives NAS search space whose _type is 'mutable_layer'"
block = block_value['_value']
tmp_block = {}
for layer_name, layer in block.items():
@@ -19,13 +20,12 @@ def random_archi_generator(nas_ss, random_state):
tmp_layer['chosen_layer'] = value[index]
elif key == 'optional_inputs':
tmp_layer['chosen_inputs'] = []
- print("zql: optional_inputs", layer['optional_inputs'])
if layer['optional_inputs']:
if isinstance(layer['optional_input_size'], int):
choice_num = layer['optional_input_size']
else:
choice_range = layer['optional_input_size']
- choice_num = random_state.randint(choice_range[0], choice_range[1]+1)
+ choice_num = random_state.randint(choice_range[0], choice_range[1] + 1)
for _ in range(choice_num):
index = random_state.randint(len(layer['optional_inputs']))
tmp_layer['chosen_inputs'].append(layer['optional_inputs'][index])
@@ -37,6 +37,7 @@ def random_archi_generator(nas_ss, random_state):
chosen_archi[block_name] = tmp_block
return chosen_archi
+
class RandomNASTuner(Tuner):
'''RandomNASTuner
'''
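Review note: the `choice_range[1] + 1` in `random_archi_generator` only makes sense if `random_state.randint` excludes its upper bound, as NumPy's `RandomState.randint` does. The sketch below isolates that sampling step. It is a stand-in under assumptions: `sample_inputs` is a hypothetical helper, and the standard-library `random.Random.randrange` (also upper-exclusive) replaces the NumPy generator so the snippet has no third-party dependency.

```python
import random


def sample_inputs(optional_inputs, optional_input_size, rng):
    """Pick inputs with replacement, mirroring the optional_inputs branch.

    optional_input_size is either an int (exact count) or a [low, high]
    pair that is inclusive on both ends.
    """
    if not optional_inputs:
        return []
    if isinstance(optional_input_size, int):
        choice_num = optional_input_size
    else:
        low, high = optional_input_size
        # randrange excludes the upper bound, hence the + 1
        choice_num = rng.randrange(low, high + 1)
    return [optional_inputs[rng.randrange(len(optional_inputs))]
            for _ in range(choice_num)]


rng = random.Random(0)
print(sample_inputs(["conv", "pool", "skip"], [1, 2], rng))
```

Because sampling is with replacement, the same input can be chosen more than once, which matches the loop in the tuner.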
diff --git a/src/nni_manager/common/experimentStartupInfo.ts b/src/nni_manager/common/experimentStartupInfo.ts
index c0c0c7a2a7..5675facdde 100644
--- a/src/nni_manager/common/experimentStartupInfo.ts
+++ b/src/nni_manager/common/experimentStartupInfo.ts
@@ -30,14 +30,13 @@ class ExperimentStartupInfo {
private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false;
- private initTrialSequenceID: number = 0;
private logDir: string = '';
private logLevel: string = '';
+ private readonly: boolean = false;
- public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string): void {
+ public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string, readonly?: boolean): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);
-
this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.basePort = basePort;
@@ -52,6 +51,10 @@ class ExperimentStartupInfo {
if (logLevel !== undefined && logLevel.length > 1) {
this.logLevel = logLevel;
}
+
+ if (readonly !== undefined) {
+ this.readonly = readonly;
+ }
}
public getExperimentId(): string {
@@ -84,15 +87,10 @@ class ExperimentStartupInfo {
return this.logLevel;
}
- public setInitTrialSequenceId(initSequenceId: number): void {
- assert(this.initialized);
- this.initTrialSequenceID = initSequenceId;
- }
-
- public getInitTrialSequenceId(): number {
+ public isReadonly(): boolean {
assert(this.initialized);
- return this.initTrialSequenceID;
+ return this.readonly;
}
}
@@ -108,23 +106,19 @@ function isNewExperiment(): boolean {
return component.get(ExperimentStartupInfo).isNewExperiment();
}
-function setInitTrialSequenceId(initSequenceId: number): void {
- component.get(ExperimentStartupInfo).setInitTrialSequenceId(initSequenceId);
-}
-
-function getInitTrialSequenceId(): number {
- return component.get(ExperimentStartupInfo).getInitTrialSequenceId();
-}
-
function getExperimentStartupInfo(): ExperimentStartupInfo {
return component.get(ExperimentStartupInfo);
}
function setExperimentStartupInfo(
- newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string): void {
+ newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string, readonly?: boolean): void {
component.get(ExperimentStartupInfo)
- .setStartupInfo(newExperiment, experimentId, basePort, logDir, logLevel);
+ .setStartupInfo(newExperiment, experimentId, basePort, logDir, logLevel, readonly);
+}
+
+function isReadonly(): boolean {
+ return component.get(ExperimentStartupInfo).isReadonly();
}
export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getExperimentStartupInfo,
- setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId };
+ setExperimentStartupInfo, isReadonly };
diff --git a/src/nni_manager/common/log.ts b/src/nni_manager/common/log.ts
index d12e235836..e2ca62f9c6 100644
--- a/src/nni_manager/common/log.ts
+++ b/src/nni_manager/common/log.ts
@@ -26,7 +26,7 @@ import { Writable } from 'stream';
import { WritableStreamBuffer } from 'stream-buffers';
import { format } from 'util';
import * as component from '../common/component';
-import { getExperimentStartupInfo } from './experimentStartupInfo';
+import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
import { getLogDir } from './utils';
const FATAL: number = 1;
@@ -76,6 +76,7 @@ class Logger {
private level: number = INFO;
private bufferSerialEmitter: BufferSerialEmitter;
private writable: Writable;
+ private readonly: boolean = false;
constructor(fileName?: string) {
let logFile: string | undefined = fileName;
@@ -95,6 +96,8 @@ class Logger {
if (logLevel !== undefined) {
this.level = logLevel;
}
+
+ this.readonly = isReadonly();
}
public close() {
@@ -134,14 +137,21 @@ class Logger {
public fatal(...param: any[]): void {
this.log('FATAL', param);
}
-
+
+ /**
+ * Write log content to the stream unless the experiment is in readonly mode.
+ * @param level log level
+ * @param param the params to be written
+ */
private log(level: string, param: any[]): void {
- const buffer: WritableStreamBuffer = new WritableStreamBuffer();
- buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
- buffer.write(format(param));
- buffer.write('\n');
- buffer.end();
- this.bufferSerialEmitter.feed(buffer.getContents());
+ if (!this.readonly) {
+ const buffer: WritableStreamBuffer = new WritableStreamBuffer();
+ buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
+ buffer.write(format(param));
+ buffer.write('\n');
+ buffer.end();
+ this.bufferSerialEmitter.feed(buffer.getContents());
+ }
}
}
diff --git a/src/nni_manager/common/manager.ts b/src/nni_manager/common/manager.ts
index 4933465b92..65ab4b77ed 100644
--- a/src/nni_manager/common/manager.ts
+++ b/src/nni_manager/common/manager.ts
@@ -24,6 +24,10 @@ import { TrialJobStatus } from './trainingService';
type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';
type ExperimentStatus = 'INITIALIZED' | 'RUNNING' | 'ERROR' | 'STOPPING' | 'STOPPED' | 'DONE' | 'NO_MORE_TRIAL' | 'TUNER_NO_MORE_TRIAL';
+namespace ExperimentStartUpMode {
+ export const NEW = 'new';
+ export const RESUME = 'resume';
+}
interface ExperimentParams {
authorName: string;
@@ -79,7 +83,7 @@ interface ExperimentProfile {
logDir?: string;
startTime?: number;
endTime?: number;
- maxSequenceId: number;
+ nextSequenceId: number;
revision: number;
}
@@ -95,7 +99,7 @@ interface NNIManagerStatus {
abstract class Manager {
public abstract startExperiment(experimentParams: ExperimentParams): Promise;
- public abstract resumeExperiment(): Promise<void>;
+ public abstract resumeExperiment(readonly: boolean): Promise<void>;
public abstract stopExperiment(): Promise;
public abstract getExperimentProfile(): Promise;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise;
@@ -111,8 +115,11 @@ abstract class Manager {
public abstract getClusterMetadata(key: string): Promise;
public abstract getMetricData(trialJobId?: string, metricType?: MetricType): Promise;
+ public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
+ public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
+
public abstract getTrialJobStatistics(): Promise;
public abstract getStatus(): NNIManagerStatus;
}
-export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus };
+export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus, ExperimentStartUpMode };
diff --git a/src/nni_manager/common/trainingService.ts b/src/nni_manager/common/trainingService.ts
index e4a9f1547e..2dfa0a9589 100644
--- a/src/nni_manager/common/trainingService.ts
+++ b/src/nni_manager/common/trainingService.ts
@@ -23,20 +23,12 @@
* define TrialJobStatus
*/
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';
-type JobType = 'TRIAL' | 'HOST';
interface TrainingServiceMetadata {
readonly key: string;
readonly value: string;
}
-/**
- * define JobApplicationForm
- */
-interface JobApplicationForm {
- readonly jobType: JobType;
-}
-
interface HyperParameters {
readonly value: string;
readonly index: number;
@@ -45,18 +37,11 @@ interface HyperParameters {
/**
* define TrialJobApplicationForm
*/
-interface TrialJobApplicationForm extends JobApplicationForm {
+interface TrialJobApplicationForm {
+ readonly sequenceId: number;
readonly hyperParameters: HyperParameters;
}
-/**
- * define HostJobApplicationForm
- */
-interface HostJobApplicationForm extends JobApplicationForm {
- readonly host: string;
- readonly cmd: string;
-}
-
/**
* define TrialJobDetail
*/
@@ -69,8 +54,7 @@ interface TrialJobDetail {
readonly tags?: string[];
readonly url?: string;
readonly workingDirectory: string;
- readonly form: JobApplicationForm;
- readonly sequenceId: number;
+ readonly form: TrialJobApplicationForm;
isEarlyStopped?: boolean;
}
@@ -112,8 +96,8 @@ abstract class TrainingService {
public abstract getTrialJob(trialJobId: string): Promise;
public abstract addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
- public abstract submitTrialJob(form: JobApplicationForm): Promise;
- public abstract updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise;
+ public abstract submitTrialJob(form: TrialJobApplicationForm): Promise;
+ public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise;
public abstract get isMultiPhaseJobSupported(): boolean;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise;
public abstract setClusterMetadata(key: string, value: string): Promise;
@@ -135,5 +119,5 @@ class NNIManagerIpConfig {
export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
- HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
+ NNIManagerIpConfig
};
diff --git a/src/nni_manager/common/utils.ts b/src/nni_manager/common/utils.ts
index 99d293c6a3..5ae7fc80cb 100644
--- a/src/nni_manager/common/utils.ts
+++ b/src/nni_manager/common/utils.ts
@@ -219,6 +219,11 @@ function getMsgDispatcherCommand(tuner: any, assessor: any, advisor: any, multiP
if (advisor.classFileName !== undefined && advisor.classFileName.length > 1) {
command += ` --advisor_class_filename ${advisor.classFileName}`;
}
+ if (advisor.gpuIndices !== undefined) {
+ command = `CUDA_VISIBLE_DEVICES=${advisor.gpuIndices} ` + command;
+ } else {
+ command = `CUDA_VISIBLE_DEVICES='' ` + command;
+ }
} else {
command += ` --tuner_class_name ${tuner.className}`;
if (tuner.classArgs !== undefined) {
@@ -243,6 +248,12 @@ function getMsgDispatcherCommand(tuner: any, assessor: any, advisor: any, multiP
command += ` --assessor_class_filename ${assessor.classFileName}`;
}
}
+
+ if (tuner.gpuIndices !== undefined) {
+ command = `CUDA_VISIBLE_DEVICES=${tuner.gpuIndices} ` + command;
+ } else {
+ command = `CUDA_VISIBLE_DEVICES='' ` + command;
+ }
}
return command;
diff --git a/src/nni_manager/core/nnimanager.ts b/src/nni_manager/core/nnimanager.ts
index c516117ca5..adbdfcda34 100644
--- a/src/nni_manager/core/nnimanager.ts
+++ b/src/nni_manager/core/nnimanager.ts
@@ -26,7 +26,7 @@ import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors';
-import { getExperimentId, setInitTrialSequenceId } from '../common/experimentStartupInfo';
+import { getExperimentId } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import {
ExperimentParams, ExperimentProfile, Manager, ExperimentStatus,
@@ -59,9 +59,10 @@ class NNIManager implements Manager {
private waitingTrials: string[];
private trialJobs: Map<string, TrialJobDetail>;
private trialDataForTuner: string;
+ private readonly: boolean;
private trialJobMetricListener: (metric: TrialJobMetric) => void;
-
+
constructor() {
this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0;
@@ -72,6 +73,7 @@ class NNIManager implements Manager {
this.waitingTrials = [];
this.trialJobs = new Map();
this.trialDataForTuner = '';
+ this.readonly = false;
this.log = getLogger();
this.dataStore = component.get(DataStore);
@@ -88,6 +90,9 @@ class NNIManager implements Manager {
}
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise {
+ if (this.readonly) {
+ return Promise.reject(new Error('Error: can not update experiment profile in readonly mode!'));
+ }
switch (updateType) {
case 'TRIAL_CONCURRENCY':
this.updateTrialConcurrency(experimentProfile.params.trialConcurrency);
@@ -109,6 +114,9 @@ class NNIManager implements Manager {
}
public importData(data: string): Promise {
+ if (this.readonly) {
+ return Promise.reject(new Error('Error: can not import data in readonly mode!'));
+ }
if (this.dispatcher === undefined) {
return Promise.reject(
new Error('tuner has not been setup')
@@ -124,6 +132,9 @@ class NNIManager implements Manager {
}
public addCustomizedTrialJob(hyperParams: string): Promise {
+ if (this.readonly) {
+ return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!'));
+ }
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject(
new Error('reach maxTrialNum')
@@ -136,6 +147,9 @@ class NNIManager implements Manager {
}
public async cancelTrialJobByUser(trialJobId: string): Promise {
+ if (this.readonly) {
+ return Promise.reject(new Error('Error: can not cancel trial job in readonly mode!'));
+ }
this.log.info(`User cancelTrialJob: ${trialJobId}`);
await this.trainingService.cancelTrialJob(trialJobId);
await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, '');
@@ -180,15 +194,17 @@ class NNIManager implements Manager {
return this.experimentProfile.id;
}
- public async resumeExperiment(): Promise<void> {
+ public async resumeExperiment(readonly: boolean): Promise<void> {
this.log.info(`Resuming experiment: ${this.experimentProfile.id}`);
//Fetch back the experiment profile
const experimentId: string = getExperimentId();
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
+ this.readonly = readonly;
+ if (readonly) {
+ return Promise.resolve();
+ }
const expParams: ExperimentParams = this.experimentProfile.params;
- setInitTrialSequenceId(this.experimentProfile.maxSequenceId + 1);
-
// Set up multiphase config
if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
@@ -196,7 +212,7 @@ class NNIManager implements Manager {
// Set up versionCheck config
if (expParams.versionCheck !== undefined) {
- this.trainingService.setClusterMetadata('versionCheck', expParams.versionCheck.toString());
+ this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString());
}
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor,
@@ -247,6 +263,9 @@ class NNIManager implements Manager {
}
public async setClusterMetadata(key: string, value: string): Promise {
+ if (this.readonly) {
+ return Promise.reject(new Error('Error: can not set cluster metadata in readonly mode!'));
+ }
this.log.info(`NNIManager setClusterMetadata, key: ${key}, value: ${value}`);
let timeoutId: NodeJS.Timer;
// TO DO: move timeout value to constants file
@@ -281,6 +300,37 @@ class NNIManager implements Manager {
return this.dataStore.getMetricData(trialJobId, metricType);
}
+ public async getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]> {
+ const trialJobs = await this.dataStore.listTrialJobs();
+ const targetTrials = trialJobs.filter(trial => (
+ // FIXME: can this be undefined?
+ trial.sequenceId !== undefined && minSeqId <= trial.sequenceId && trial.sequenceId <= maxSeqId
+ ));
+ const targetTrialIds = new Set(targetTrials.map(trial => trial.id));
+
+ const allMetrics = await this.dataStore.getMetricData();
+ return allMetrics.filter(metric => targetTrialIds.has(metric.trialJobId));
+ }
+
+ public async getLatestMetricData(): Promise<MetricDataRecord[]> {
+ // FIXME: this can take a long time
+ const allMetrics: MetricDataRecord[] = await this.dataStore.getMetricData();
+ const finals: MetricDataRecord[] = [];
+ const latestIntermediates: Map<string, MetricDataRecord> = new Map<string, MetricDataRecord>();
+ for (const metric of allMetrics) {
+ if (metric.type !== 'PERIODICAL') {
+ finals.push(metric);
+ } else {
+ const old: MetricDataRecord | undefined = latestIntermediates.get(metric.trialJobId);
+ if (old === undefined || old.sequence <= metric.sequence) {
+ latestIntermediates.set(metric.trialJobId, metric);
+ }
+ }
+ }
+ return finals.concat(Array.from(latestIntermediates.values()));
+ // FIXME: unit test
+ }
+
public getExperimentProfile(): Promise {
// TO DO: using Promise.resolve()
const deferred: Deferred = new Deferred();
@@ -363,7 +413,7 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup');
}
- this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
+ this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
this.dispatcher.sendCommand(TERMINATE);
let tunerAlive: boolean = true;
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
@@ -436,11 +486,7 @@ class NNIManager implements Manager {
case 'EARLY_STOPPED':
this.trialJobs.delete(trialJobId);
finishedTrialJobNum++;
- if (trialJobDetail.form.jobType === 'TRIAL') {
- hyperParams = (trialJobDetail.form).hyperParameters.value;
- } else {
- throw new Error('Error: jobType error, not TRIAL');
- }
+ hyperParams = trialJobDetail.form.hyperParameters.value;
this.dispatcher.sendCommand(TRIAL_END, JSON.stringify({
trial_job_id: trialJobDetail.id,
event: trialJobDetail.status,
@@ -453,11 +499,7 @@ class NNIManager implements Manager {
// TO DO: push this job to queue for retry
this.trialJobs.delete(trialJobId);
finishedTrialJobNum++;
- if (trialJobDetail.form.jobType === 'TRIAL') {
- hyperParams = (trialJobDetail.form).hyperParameters.value;
- } else {
- throw new Error('Error: jobType error, not TRIAL');
- }
+ hyperParams = trialJobDetail.form.hyperParameters.value;
this.dispatcher.sendCommand(TRIAL_END, JSON.stringify({
trial_job_id: trialJobDetail.id,
event: trialJobDetail.status,
@@ -556,7 +598,7 @@ class NNIManager implements Manager {
}
this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = {
- jobType: 'TRIAL',
+ sequenceId: this.experimentProfile.nextSequenceId++,
hyperParameters: {
value: hyperParams,
index: 0
@@ -564,7 +606,7 @@ class NNIManager implements Manager {
};
this.log.info(`submitTrialJob: form: ${JSON.stringify(trialJobAppForm)}`);
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
- await this.storeMaxSequenceId(trialJobDetail.sequenceId);
+ await this.storeExperimentProfile();
this.trialJobs.set(trialJobDetail.id, Object.assign({}, trialJobDetail));
const trialJobDetailSnapshot: TrialJobDetail | undefined = this.trialJobs.get(trialJobDetail.id);
if (trialJobDetailSnapshot != undefined) {
@@ -683,7 +725,7 @@ class NNIManager implements Manager {
assert(tunerCommand.trial_job_id !== undefined);
const trialJobForm: TrialJobApplicationForm = {
- jobType: 'TRIAL',
+ sequenceId: -1, // FIXME: multi-phase tuner should use sequence ID instead of trial job ID
hyperParameters: {
value: content,
index: tunerCommand.parameter_index
@@ -691,8 +733,11 @@ class NNIManager implements Manager {
};
this.log.info(`updateTrialJob: job id: ${tunerCommand.trial_job_id}, form: ${JSON.stringify(trialJobForm)}`);
await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
- await this.dataStore.storeTrialJobEvent(
- 'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
+ if (tunerCommand['parameters'] !== null) {
+ // parameters field is set as empty string if no more hyper parameter can be generated by tuner.
+ await this.dataStore.storeTrialJobEvent(
+ 'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
+ }
break;
case NO_MORE_TRIAL_JOBS:
if (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) {
@@ -734,7 +779,7 @@ class NNIManager implements Manager {
revision: 0,
execDuration: 0,
logDir: getExperimentRootDir(),
- maxSequenceId: 0,
+ nextSequenceId: 0,
params: {
authorName: '',
experimentName: '',
@@ -765,13 +810,6 @@ class NNIManager implements Manager {
return Promise.resolve(chkpDir);
}
-
- private async storeMaxSequenceId(sequenceId: number): Promise {
- if (sequenceId > this.experimentProfile.maxSequenceId) {
- this.experimentProfile.maxSequenceId = sequenceId;
- await this.storeExperimentProfile();
- }
- }
}
export { NNIManager };
diff --git a/src/nni_manager/core/sqlDatabase.ts b/src/nni_manager/core/sqlDatabase.ts
index 182011dcba..8ae29413bc 100644
--- a/src/nni_manager/core/sqlDatabase.ts
+++ b/src/nni_manager/core/sqlDatabase.ts
@@ -54,7 +54,7 @@ create table ExperimentProfile (
startTime integer,
endTime integer,
logDir text,
- maxSequenceId integer,
+ nextSequenceId integer,
revision integer);
create index ExperimentProfile_id on ExperimentProfile(id);
`;
@@ -67,7 +67,7 @@ function loadExperimentProfile(row: any): ExperimentProfile {
startTime: row.startTime === null ? undefined : row.startTime,
endTime: row.endTime === null ? undefined : row.endTime,
logDir: row.logDir === null ? undefined : row.logDir,
- maxSequenceId: row.maxSequenceId,
+ nextSequenceId: row.nextSequenceId,
revision: row.revision
};
}
@@ -144,7 +144,7 @@ class SqlDB implements Database {
exp.startTime === undefined ? null : exp.startTime,
exp.endTime === undefined ? null : exp.endTime,
exp.logDir === undefined ? null : exp.logDir,
- exp.maxSequenceId,
+ exp.nextSequenceId,
exp.revision
];
this.log.trace(`storeExperimentProfile: SQL: ${sql}, args: ${JSON.stringify(args)}`);
@@ -183,7 +183,7 @@ class SqlDB implements Database {
event: TrialJobEvent, trialJobId: string, timestamp: number, hyperParameter?: string, jobDetail?: TrialJobDetail): Promise {
const sql: string = 'insert into TrialJobEvent values (?,?,?,?,?,?)';
const logPath: string | undefined = jobDetail === undefined ? undefined : jobDetail.url;
- const sequenceId: number | undefined = jobDetail === undefined ? undefined : jobDetail.sequenceId;
+ const sequenceId: number | undefined = jobDetail === undefined ? undefined : jobDetail.form.sequenceId;
const args: any[] = [timestamp, trialJobId, event, hyperParameter, logPath, sequenceId];
this.log.trace(`storeTrialJobEvent: SQL: ${sql}, args: ${JSON.stringify(args)}`);
diff --git a/src/nni_manager/core/test/dataStore.test.ts b/src/nni_manager/core/test/dataStore.test.ts
index d0303990bb..6794706672 100644
--- a/src/nni_manager/core/test/dataStore.test.ts
+++ b/src/nni_manager/core/test/dataStore.test.ts
@@ -80,7 +80,7 @@ describe('Unit test for dataStore', () => {
execDuration: 0,
startTime: Date.now(),
endTime: Date.now(),
- maxSequenceId: 0,
+ nextSequenceId: 0,
revision: 0
}
const id: string = profile.id;
diff --git a/src/nni_manager/core/test/mockedTrainingService.ts b/src/nni_manager/core/test/mockedTrainingService.ts
index 027234de9e..f50fb62113 100644
--- a/src/nni_manager/core/test/mockedTrainingService.ts
+++ b/src/nni_manager/core/test/mockedTrainingService.ts
@@ -41,9 +41,9 @@ class MockedTrainingService extends TrainingService {
url: 'http://test',
workingDirectory: '/tmp/mocked',
form: {
- jobType: 'TRIAL'
+ sequenceId: 0,
+ hyperParameters: { value: '', index: 0 }
},
- sequenceId: 0
};
public jobDetail2: TrialJobDetail = {
id: '3456',
@@ -55,9 +55,9 @@ class MockedTrainingService extends TrainingService {
url: 'http://test',
workingDirectory: '/tmp/mocked',
form: {
- jobType: 'TRIAL'
+ sequenceId: 1,
+ hyperParameters: { value: '', index: 1 }
},
- sequenceId: 0
};
public listTrialJobs(): Promise {
diff --git a/src/nni_manager/core/test/nnimanager.test.ts b/src/nni_manager/core/test/nnimanager.test.ts
index 1b22ba7315..2eac8b1c8c 100644
--- a/src/nni_manager/core/test/nnimanager.test.ts
+++ b/src/nni_manager/core/test/nnimanager.test.ts
@@ -101,7 +101,7 @@ describe('Unit test for nnimanager', function () {
params: updateExperimentParams,
id: 'test',
execDuration: 0,
- maxSequenceId: 0,
+ nextSequenceId: 0,
revision: 0
}
diff --git a/src/nni_manager/core/test/sqlDatabase.test.ts b/src/nni_manager/core/test/sqlDatabase.test.ts
index f48e0d978e..d292776a3c 100644
--- a/src/nni_manager/core/test/sqlDatabase.test.ts
+++ b/src/nni_manager/core/test/sqlDatabase.test.ts
@@ -64,10 +64,10 @@ const expParams2: ExperimentParams = {
};
const profiles: ExperimentProfile[] = [
- { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: undefined, maxSequenceId: 0, revision: 1,},
- { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 2 },
- { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 2 },
- { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 3 }
+ { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: undefined, nextSequenceId: 0, revision: 1,},
+ { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 1, revision: 2 },
+ { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 0, revision: 2 },
+ { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 2, revision: 3 }
];
const events: TrialJobEventRecord[] = [
diff --git a/src/nni_manager/main.ts b/src/nni_manager/main.ts
index b946894ac3..fec5a8819e 100644
--- a/src/nni_manager/main.ts
+++ b/src/nni_manager/main.ts
@@ -26,7 +26,7 @@ import * as component from './common/component';
import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log';
-import { Manager } from './common/manager';
+import { Manager, ExperimentStartUpMode } from './common/manager';
import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils';
import { NNIDataStore } from './core/nniDataStore';
@@ -43,10 +43,10 @@ import {
function initStartupInfo(
startExpMode: string, resumeExperimentId: string, basePort: number,
- logDirectory: string, experimentLogLevel: string): void {
- const createNew: boolean = (startExpMode === 'new');
+ logDirectory: string, experimentLogLevel: string, readonly: boolean): void {
+ const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW);
const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
- setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel);
+ setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
}
async function initContainer(platformMode: string): Promise<void> {
@@ -108,15 +108,15 @@ if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'].includes(mode
}
const startMode: string = parseArg(['--start_mode', '-s']);
-if (!['new', 'resume'].includes(startMode)) {
+if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMode)) {
console.log(`FATAL: unknown start_mode: ${startMode}`);
usage();
process.exit(1);
}
const experimentId: string = parseArg(['--experiment_id', '-id']);
-if (startMode === 'resume' && experimentId.trim().length < 1) {
- console.log(`FATAL: cannot resume experiment, invalid experiment_id: ${experimentId}`);
+if ((startMode === ExperimentStartUpMode.RESUME) && experimentId.trim().length < 1) {
+ console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`);
usage();
process.exit(1);
}
@@ -133,7 +133,15 @@ if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
console.log(`FATAL: invalid log_level: ${logLevel}`);
}
-initStartupInfo(startMode, experimentId, port, logDir, logLevel);
+const readonlyArg: string = parseArg(['--readonly', '-r']);
+if (readonlyArg.length > 0 && !['true', 'false'].includes(readonlyArg.toLowerCase())) {
+ console.log(`FATAL: readonly property should only be true or false`);
+ usage();
+ process.exit(1);
+}
+const readonly: boolean = readonlyArg.toLowerCase() === 'true';
+
+initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir())
.then(async () => {
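Note on the `--readonly` validation in `main.ts`: a membership test over an array of allowed values is one way to validate such a flag. The following is a minimal sketch, not the project's implementation. `parseBooleanFlag` is a hypothetical helper name, and it assumes that an absent flag arrives as an empty string and should mean `false`.

```typescript
// Hypothetical helper (not part of NNI): parse an optional boolean CLI value.
// Assumption: an absent flag arrives as '' and is treated as `false`.
function parseBooleanFlag(raw: string, flagName: string): boolean {
    const value: string = raw.trim().toLowerCase();
    if (value === '') {
        return false;
    }
    // Array membership: only the two exact spellings are accepted.
    if (!['true', 'false'].includes(value)) {
        throw new Error(`${flagName} should only be true or false, got: ${raw}`);
    }
    return value === 'true';
}
```

Throwing instead of calling `process.exit` keeps the helper testable; a caller could catch the error, print usage, and exit.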
diff --git a/src/nni_manager/rest_server/restHandler.ts b/src/nni_manager/rest_server/restHandler.ts
index 7c11b15e72..83c95a2987 100644
--- a/src/nni_manager/rest_server/restHandler.ts
+++ b/src/nni_manager/rest_server/restHandler.ts
@@ -25,9 +25,9 @@ import * as path from 'path';
import * as component from '../common/component';
import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore';
import { NNIError, NNIErrorNames } from '../common/errors';
-import { isNewExperiment } from '../common/experimentStartupInfo';
+import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
-import { ExperimentProfile, Manager, TrialJobStatistics} from '../common/manager';
+import { ExperimentProfile, Manager, TrialJobStatistics, ExperimentStartUpMode } from '../common/manager';
import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils';
@@ -72,6 +72,8 @@ class NNIRestHandler {
this.addTrialJob(router);
this.cancelTrialJob(router);
this.getMetricData(router);
+ this.getMetricDataByRange(router);
+ this.getLatestMetricData(router);
this.exportData(router);
// Express-joi-validator configuration
@@ -86,11 +88,11 @@ class NNIRestHandler {
return router;
}
- private handle_error(err: Error, res: Response, isFatal: boolean = false): void {
+ private handle_error(err: Error, res: Response, isFatal: boolean = false, errorCode: number = 500): void {
if (err instanceof NNIError && err.name === NNIErrorNames.NOT_FOUND) {
res.status(404);
} else {
- res.status(500);
+ res.status(errorCode);
}
res.send({
error: err.message
@@ -169,13 +171,13 @@ class NNIRestHandler {
this.handle_error(err, res);
});
} else {
- this.nniManager.resumeExperiment().then(() => {
+ this.nniManager.resumeExperiment(isReadonly()).then(() => {
res.send();
}).catch((err: Error) => {
// Resume experiment is a step of initialization, so any exception thrown is a fatal
this.handle_error(err, res);
});
- }
+ }
});
}
@@ -193,18 +195,18 @@ class NNIRestHandler {
router.put(
'/experiment/cluster-metadata', expressJoi(ValidationSchemas.SETCLUSTERMETADATA),
async (req: Request, res: Response) => {
- // tslint:disable-next-line:no-any
- const metadata: any = req.body;
- const keys: string[] = Object.keys(metadata);
- try {
- for (const key of keys) {
- await this.nniManager.setClusterMetadata(key, JSON.stringify(metadata[key]));
+ // tslint:disable-next-line:no-any
+ const metadata: any = req.body;
+ const keys: string[] = Object.keys(metadata);
+ try {
+ for (const key of keys) {
+ await this.nniManager.setClusterMetadata(key, JSON.stringify(metadata[key]));
+ }
+ res.send();
+ } catch (err) {
+ // setClusterMetadata is a step of initialization, so any exception thrown is fatal
+ this.handle_error(NNIError.FromError(err), res, true);
}
- res.send();
- } catch (err) {
- // setClusterMetata is a step of initialization, so any exception thrown is a fatal
- this.handle_error(NNIError.FromError(err), res, true);
- }
});
}
@@ -262,6 +264,28 @@ class NNIRestHandler {
});
}
+ private getMetricDataByRange(router: Router): void {
+ router.get('/metric-data-range/:min_seq_id/:max_seq_id', async (req: Request, res: Response) => {
+ const minSeqId = Number(req.params.min_seq_id);
+ const maxSeqId = Number(req.params.max_seq_id);
+ this.nniManager.getMetricDataByRange(minSeqId, maxSeqId).then((metricsData: MetricDataRecord[]) => {
+ res.send(metricsData);
+ }).catch((err: Error) => {
+ this.handle_error(err, res);
+ });
+ });
+ }
+
+ private getLatestMetricData(router: Router): void {
+ router.get('/metric-data-latest/', async (req: Request, res: Response) => {
+ this.nniManager.getLatestMetricData().then((metricsData: MetricDataRecord[]) => {
+ res.send(metricsData);
+ }).catch((err: Error) => {
+ this.handle_error(err, res);
+ });
+ });
+ }
+
private exportData(router: Router): void {
router.get('/export-data', (req: Request, res: Response) => {
this.nniManager.exportData().then((exportedData: string) => {
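Note on `/metric-data-range/:min_seq_id/:max_seq_id`: `Number(...)` yields `NaN` for a non-numeric path segment, so a handler may want to validate the bounds before querying. The sketch below shows one possible validation and range filter. The record shapes (`TrialInfo`, `MetricRecord`), the helper names, and the inclusive range are all assumptions made for illustration, not the manager's actual logic.

```typescript
// Hypothetical, simplified record shapes for illustration only.
interface TrialInfo { id: string; sequenceId?: number; }
interface MetricRecord { trialJobId: string; data: string; }

// Parse a path parameter as a non-negative integer.
// Number('abc') is NaN and Number('') is 0, so both cases are rejected explicitly.
function parseSeqId(raw: string): number {
    const n: number = Number(raw);
    if (raw.trim() === '' || !Number.isInteger(n) || n < 0) {
        throw new Error(`invalid sequence id: ${raw}`);
    }
    return n;
}

// Assumption: the range is inclusive on both ends.
function filterMetricsByRange(
    trials: TrialInfo[], metrics: MetricRecord[], minSeqId: number, maxSeqId: number): MetricRecord[] {
    const ids: Set<string> = new Set(
        trials
            .filter((t: TrialInfo) => t.sequenceId !== undefined && minSeqId <= t.sequenceId && t.sequenceId <= maxSeqId)
            .map((t: TrialInfo) => t.id));
    // Keep only metrics that belong to a trial inside the range.
    return metrics.filter((m: MetricRecord) => ids.has(m.trialJobId));
}
```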
diff --git a/src/nni_manager/rest_server/restValidationSchemas.ts b/src/nni_manager/rest_server/restValidationSchemas.ts
index 56b1b2c633..99bbe4bb96 100644
--- a/src/nni_manager/rest_server/restValidationSchemas.ts
+++ b/src/nni_manager/rest_server/restValidationSchemas.ts
@@ -170,18 +170,18 @@ export namespace ValidationSchemas {
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
- gpuNum: joi.number().min(0),
- checkpointDir: joi.string().allow('')
+ checkpointDir: joi.string().allow(''),
+ gpuIndices: joi.string()
}),
tuner: joi.object({
- builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'),
+ builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
codeDir: joi.string(),
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
- gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow(''),
- includeIntermediateResults: joi.boolean()
+ includeIntermediateResults: joi.boolean(),
+ gpuIndices: joi.string()
}),
assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
@@ -189,7 +189,6 @@ export namespace ValidationSchemas {
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
- gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
}),
clusterMetaData: joi.array().items(joi.object({
@@ -210,7 +209,7 @@ export namespace ValidationSchemas {
startTime: joi.number(),
endTime: joi.number(),
logDir: joi.string(),
- maxSequenceId: joi.number()
+ nextSequenceId: joi.number()
}
};
}
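Note on the schema change from `gpuNum` to `gpuIndices: joi.string()`: the schema only checks that the value is a string, so whatever consumes it still has to interpret the contents. The sketch below assumes a comma-separated list of non-negative integers such as `"0,1"`. Both that format and the `parseGpuIndices` helper are assumptions for illustration and are not taken from this diff.

```typescript
// Hypothetical helper (not part of NNI).
// Assumption: gpuIndices is a comma-separated list of non-negative integers, e.g. "0,1".
function parseGpuIndices(gpuIndices: string | undefined): number[] | undefined {
    if (gpuIndices === undefined) {
        return undefined; // field omitted: no explicit GPU selection
    }
    return gpuIndices
        .split(',')
        .map((part: string) => part.trim())
        .filter((part: string) => part.length > 0)
        .map((part: string) => {
            // Accept digits only, so values like "-1" or "a" are rejected.
            if (!/^\d+$/.test(part)) {
                throw new Error(`invalid GPU index: ${part}`);
            }
            return Number(part);
        });
}
```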
diff --git a/src/nni_manager/rest_server/test/mockedNNIManager.ts b/src/nni_manager/rest_server/test/mockedNNIManager.ts
index 299c473aa6..3c4a502ec8 100644
--- a/src/nni_manager/rest_server/test/mockedNNIManager.ts
+++ b/src/nni_manager/rest_server/test/mockedNNIManager.ts
@@ -85,9 +85,9 @@ export class MockedNNIManager extends Manager {
// tslint:disable-next-line:no-http-string
url: 'http://test',
workingDirectory: '/tmp/mocked',
- sequenceId: 0,
form: {
- jobType: 'TRIAL'
+ sequenceId: 0,
+ hyperParameters: { value: '', index: 0 }
}
};
deferred.resolve(jobDetail);
@@ -129,6 +129,12 @@ export class MockedNNIManager extends Manager {
public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
+ public getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]> {
+ throw new MethodNotImplementedError();
+ }
+ public getLatestMetricData(): Promise<MetricDataRecord[]> {
+ throw new MethodNotImplementedError();
+ }
public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = {
params: {
@@ -148,7 +154,7 @@ export class MockedNNIManager extends Manager {
execDuration: 0,
startTime: Date.now(),
endTime: Date.now(),
- maxSequenceId: 0,
+ nextSequenceId: 0,
revision: 0
};
diff --git a/src/nni_manager/training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService.ts b/src/nni_manager/training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService.ts
index 51d56d5b7c..f54bd11e9e 100644
--- a/src/nni_manager/training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService.ts
+++ b/src/nni_manager/training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService.ts
@@ -25,7 +25,7 @@ import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
- JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
+ NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
@@ -55,7 +55,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
super();
this.fcJobInfoCollector = new FrameworkControllerJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId();
- this.nextTrialSequenceId = -1;
}
public async run(): Promise<void> {
@@ -77,7 +76,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
}
}
- public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.fcClusterConfig === undefined) {
throw new Error('frameworkcontrollerClusterConfig is not initialized');
}
@@ -91,14 +90,13 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
}
const trialJobId: string = uniqueString(5);
- const curTrialSequenceId: number = this.generateSequenceId();
// Set trial's NFS working folder
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase();
//Generate the port used for taskRole
this.generateContainerPort();
- await this.prepareRunScript(trialLocalTempFolder, curTrialSequenceId, trialJobId, trialWorkingFolder, form);
+ await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload code files
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
@@ -113,7 +111,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
trialWorkingFolder,
form,
frameworkcontrollerJobName,
- curTrialSequenceId,
trialJobOutputUrl
);
@@ -248,8 +245,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
return `${portScript} . /mnt/frameworkbarrier/injector.sh && ${command}`;
}
- private async prepareRunScript(trialLocalTempFolder: string, curTrialSequenceId: number, trialJobId: string,
- trialWorkingFolder: string, form: JobApplicationForm): Promise<void> {
+ private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string,
+ trialWorkingFolder: string, form: TrialJobApplicationForm): Promise<void> {
if (this.fcTrialConfig === undefined) {
throw new Error('frameworkcontroller trial config is not initialized');
}
@@ -264,16 +261,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
for (const taskRole of this.fcTrialConfig.taskRoles) {
const runScriptContent: string =
await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder,
- this.generateCommandScript(taskRole.command), curTrialSequenceId.toString(),
+ this.generateCommandScript(taskRole.command), form.sequenceId.toString(),
taskRole.name, taskRole.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, { encoding: 'utf8' });
}
// Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form);
- if (trialForm !== undefined && trialForm.hyperParameters !== undefined) {
- await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
- trialForm.hyperParameters.value, { encoding: 'utf8' });
+ if (form !== undefined) {
+ await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
+ form.hyperParameters.value, { encoding: 'utf8' });
}
}
diff --git a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts
index e70246176a..de61deb3ef 100644
--- a/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts
+++ b/src/nni_manager/training_service/kubernetes/kubeflow/kubeflowTrainingService.ts
@@ -27,7 +27,7 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
- JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
+ NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
@@ -59,7 +59,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
super();
this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId();
- this.nextTrialSequenceId = -1;
this.log.info('Construct Kubeflow training service.');
}
@@ -84,7 +83,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
this.log.info('Kubeflow training service exit.');
}
- public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.kubernetesCRDClient === undefined) {
throw new Error('Kubeflow job operator client is undefined');
}
@@ -96,10 +95,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
- const curTrialSequenceId: number = this.generateSequenceId();
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
//prepare the runscript
- await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, curTrialSequenceId, form);
+ await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload files to sotrage
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
let initStatus: TrialJobStatus = 'WAITING';
@@ -113,7 +111,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
trialWorkingFolder,
form,
kubeflowJobName,
- curTrialSequenceId,
trialJobOutputUrl
);
@@ -236,8 +233,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
return Promise.resolve(trialJobOutputUrl);
}
- private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string, curTrialSequenceId: number,
- form: JobApplicationForm): Promise<void> {
+ private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string,
+ form: TrialJobApplicationForm): Promise<void> {
if (this.kubeflowClusterConfig === undefined) {
throw new Error('Kubeflow Cluster config is not initialized');
}
@@ -262,7 +259,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (kubeflowTrialConfig.worker !== undefined) {
const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
kubeflowTrialConfig.worker.command,
- curTrialSequenceId.toString(), 'worker',
+ form.sequenceId.toString(), 'worker',
kubeflowTrialConfig.worker.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' });
}
@@ -272,7 +269,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (tensorflowTrialConfig.ps !== undefined) {
const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
tensorflowTrialConfig.ps.command,
- curTrialSequenceId.toString(),
+ form.sequenceId.toString(),
'ps', tensorflowTrialConfig.ps.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' });
}
@@ -281,16 +278,15 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (pytorchTrialConfig.master !== undefined) {
const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
pytorchTrialConfig.master.command,
- curTrialSequenceId.toString(), 'master',
+ form.sequenceId.toString(), 'master',
pytorchTrialConfig.master.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' });
}
}
// Write file content ( parameter.cfg ) to local tmp folders
- const trialForm : TrialJobApplicationForm = (form);
- if (trialForm !== undefined && trialForm.hyperParameters !== undefined) {
- await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
- trialForm.hyperParameters.value, { encoding: 'utf8' });
+ if (form !== undefined) {
+ await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
+ form.hyperParameters.value, { encoding: 'utf8' });
}
}
diff --git a/src/nni_manager/training_service/kubernetes/kubernetesData.ts b/src/nni_manager/training_service/kubernetes/kubernetesData.ts
index b52e9a3049..f49f67eee0 100644
--- a/src/nni_manager/training_service/kubernetes/kubernetesData.ts
+++ b/src/nni_manager/training_service/kubernetes/kubernetesData.ts
@@ -19,7 +19,7 @@
'use strict';
-import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
+import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
* KubeflowTrialJobDetail
@@ -33,21 +33,19 @@ export class KubernetesTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
- public form: JobApplicationForm;
+ public form: TrialJobApplicationForm;
public kubernetesJobName: string;
- public sequenceId: number;
public queryJobFailedCount: number;
constructor(id: string, status: TrialJobStatus, submitTime: number,
- workingDirectory: string, form: JobApplicationForm,
- kubernetesJobName: string, sequenceId: number, url: string) {
+ workingDirectory: string, form: TrialJobApplicationForm,
+ kubernetesJobName: string, url: string) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.kubernetesJobName = kubernetesJobName;
- this.sequenceId = sequenceId;
this.tags = [];
this.queryJobFailedCount = 0;
this.url = url;
diff --git a/src/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts b/src/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
index 6a1df6e0f2..62e4916599 100644
--- a/src/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
+++ b/src/nni_manager/training_service/kubernetes/kubernetesTrainingService.ts
@@ -26,7 +26,7 @@ import * as azureStorage from 'azure-storage';
import { EventEmitter } from 'events';
import { Base64 } from 'js-base64';
import { String } from 'typescript-string-operations';
-import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
+import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
@@ -53,7 +53,6 @@ abstract class KubernetesTrainingService {
protected readonly trialLocalNFSTempFolder: string;
protected stopping: boolean = false;
protected experimentId! : string;
- protected nextTrialSequenceId: number;
protected kubernetesRestServerPort?: number;
protected readonly CONTAINER_MOUNT_PATH: string;
protected azureStorageClient?: azureStorage.FileService;
@@ -74,7 +73,6 @@ abstract class KubernetesTrainingService {
this.trialJobsMap = new Map();
this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp');
this.experimentId = getExperimentId();
- this.nextTrialSequenceId = -1;
this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient();
this.logCollection = 'none';
@@ -93,9 +91,7 @@ abstract class KubernetesTrainingService {
const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) {
- if (value.form.jobType === 'TRIAL') {
- jobs.push(await this.getTrialJob(key));
- }
+ jobs.push(await this.getTrialJob(key));
}
return Promise.resolve(jobs);
@@ -222,14 +218,6 @@ abstract class KubernetesTrainingService {
return Promise.resolve();
}
- protected generateSequenceId(): number {
- if (this.nextTrialSequenceId === -1) {
- this.nextTrialSequenceId = getInitTrialSequenceId();
- }
-
- return this.nextTrialSequenceId++;
- }
-
// tslint:disable: no-unsafe-any no-any
protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> {
try {
diff --git a/src/nni_manager/training_service/local/localTrainingService.ts b/src/nni_manager/training_service/local/localTrainingService.ts
index 88e006a3f9..1a7c70d3a1 100644
--- a/src/nni_manager/training_service/local/localTrainingService.ts
+++ b/src/nni_manager/training_service/local/localTrainingService.ts
@@ -26,10 +26,10 @@ import * as path from 'path';
import * as ts from 'tail-stream';
import * as tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from '../../common/errors';
-import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
+import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
- HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm,
+ HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import {
@@ -76,21 +76,19 @@ class LocalTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
- public form: JobApplicationForm;
- public sequenceId: number;
+ public form: TrialJobApplicationForm;
public pid?: number;
public gpuIndices?: number[];
constructor(
id: string, status: TrialJobStatus, submitTime: number,
- workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
+ workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.url = `file://localhost:${workingDirectory}`;
- this.sequenceId = sequenceId;
this.gpuIndices = [];
}
}
@@ -125,7 +123,6 @@ class LocalTrainingService implements TrainingService {
private initialized: boolean;
private stopping: boolean;
private rootDir!: string;
- private trialSequenceId: number;
private readonly experimentId! : string;
private gpuScheduler!: GPUScheduler;
private readonly occupiedGpuIndexNumMap: Map;
@@ -145,7 +142,6 @@ class LocalTrainingService implements TrainingService {
this.initialized = false;
this.stopping = false;
this.log = getLogger();
- this.trialSequenceId = -1;
this.experimentId = getExperimentId();
this.jobStreamMap = new Map();
this.log.info('Construct local machine training service.');
@@ -169,9 +165,7 @@ class LocalTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
for (const key of this.jobMap.keys()) {
const trialJob: TrialJobDetail = await this.getTrialJob(key);
- if (trialJob.form.jobType === 'TRIAL') {
- jobs.push(trialJob);
- }
+ jobs.push(trialJob);
}
return jobs;
@@ -182,9 +176,6 @@ class LocalTrainingService implements TrainingService {
if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found');
}
- if (trialJob.form.jobType === 'HOST') {
- return this.getHostJob(trialJobId);
- }
if (trialJob.status === 'RUNNING') {
const alive: boolean = await isAlive(trialJob.pid);
if (!alive) {
@@ -219,28 +210,21 @@ class LocalTrainingService implements TrainingService {
this.eventEmitter.off('metric', listener);
}
- public submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
- if (form.jobType === 'HOST') {
- return this.runHostJob(form);
- } else if (form.jobType === 'TRIAL') {
- const trialJobId: string = uniqueString(5);
- const trialJobDetail: LocalTrialJobDetail = new LocalTrialJobDetail(
- trialJobId,
- 'WAITING',
- Date.now(),
- path.join(this.rootDir, 'trials', trialJobId),
- form,
- this.generateSequenceId()
- );
- this.jobQueue.push(trialJobId);
- this.jobMap.set(trialJobId, trialJobDetail);
-
- this.log.debug(`submitTrialJob: return: ${JSON.stringify(trialJobDetail)} `);
-
- return Promise.resolve(trialJobDetail);
- } else {
- return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}`));
- }
+ public submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
+ const trialJobId: string = uniqueString(5);
+ const trialJobDetail: LocalTrialJobDetail = new LocalTrialJobDetail(
+ trialJobId,
+ 'WAITING',
+ Date.now(),
+ path.join(this.rootDir, 'trials', trialJobId),
+ form
+ );
+ this.jobQueue.push(trialJobId);
+ this.jobMap.set(trialJobId, trialJobDetail);
+
+ this.log.debug(`submitTrialJob: return: ${JSON.stringify(trialJobDetail)} `);
+
+ return Promise.resolve(trialJobDetail);
}
/**
@@ -248,16 +232,12 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
- public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
- if (form.jobType === 'TRIAL') {
- await this.writeParameterFile(trialJobDetail.workingDirectory, (form).hyperParameters);
- } else {
- throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
- }
+ await this.writeParameterFile(trialJobDetail.workingDirectory, form.hyperParameters);
return trialJobDetail;
}
@@ -279,13 +259,7 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve();
}
- if (trialJob.form.jobType === 'TRIAL') {
- tkill(trialJob.pid, 'SIGKILL');
- } else if (trialJob.form.jobType === 'HOST') {
- await cpp.exec(`pkill -9 -P ${trialJob.pid}`);
- } else {
- throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
- }
+ tkill(trialJob.pid, 'SIGKILL');
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve();
@@ -409,7 +383,7 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
- { key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.sequenceId.toString() },
+ { key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.form.sequenceId.toString() },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
];
if (gpuNum !== undefined) {
@@ -562,7 +536,7 @@ class LocalTrainingService implements TrainingService {
const scriptName: string = getScriptName('run');
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName),
runScriptContent.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
- await this.writeParameterFile(trialJobDetail.workingDirectory, (trialJobDetail.form).hyperParameters);
+ await this.writeParameterFile(trialJobDetail.workingDirectory, trialJobDetail.form.hyperParameters);
const trialJobProcess: cp.ChildProcess = runScript(path.join(trialJobDetail.workingDirectory, scriptName));
this.setTrialJobStatus(trialJobDetail, 'RUNNING');
trialJobDetail.startTime = Date.now();
@@ -589,60 +563,10 @@ class LocalTrainingService implements TrainingService {
this.jobStreamMap.set(trialJobDetail.id, stream);
}
- private async runHostJob(form: HostJobApplicationForm): Promise {
- const jobId: string = uniqueString(5);
- const workDir: string = path.join(this.rootDir, 'hostjobs', jobId);
- await cpp.exec(`mkdir -p ${workDir}`);
- const wrappedCmd: string = `cd ${workDir} && ${form.cmd}>stdout 2>stderr`;
- this.log.debug(`runHostJob: command: ${wrappedCmd}`);
- const process: cp.ChildProcess = cp.exec(wrappedCmd);
- const jobDetail: LocalTrialJobDetail = {
- id: jobId,
- status: 'RUNNING',
- submitTime: Date.now(),
- workingDirectory: workDir,
- form: form,
- sequenceId: this.generateSequenceId(),
- pid: process.pid
- };
- this.jobMap.set(jobId, jobDetail);
- this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
-
- return jobDetail;
- }
-
- private async getHostJob(jobId: string): Promise {
- const jobDetail: LocalTrialJobDetail | undefined = this.jobMap.get(jobId);
- if (jobDetail === undefined) {
- throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${jobId}`);
- }
- try {
- await cpp.exec(`kill -0 ${jobDetail.pid}`);
-
- return jobDetail;
- } catch (error) {
- if (error instanceof Error) {
- this.log.debug(`getHostJob: error: ${error.message}`);
- this.jobMap.delete(jobId);
- throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${error.message}`);
- } else {
- throw error;
- }
- }
- }
-
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
}
-
- private generateSequenceId(): number {
- if (this.trialSequenceId === -1) {
- this.trialSequenceId = getInitTrialSequenceId();
- }
-
- return this.trialSequenceId++;
- }
}
export { LocalTrainingService };
diff --git a/src/nni_manager/training_service/pai/paiData.ts b/src/nni_manager/training_service/pai/paiData.ts
index 8ac4b77ed1..d10be902e1 100644
--- a/src/nni_manager/training_service/pai/paiData.ts
+++ b/src/nni_manager/training_service/pai/paiData.ts
@@ -19,7 +19,7 @@
'use strict';
-import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
+import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
* PAI trial job detail
@@ -34,20 +34,18 @@ export class PAITrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
- public form: JobApplicationForm;
- public sequenceId: number;
+ public form: TrialJobApplicationForm;
public hdfsLogPath: string;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string,
- submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
+ submitTime: number, workingDirectory: string, form: TrialJobApplicationForm, hdfsLogPath: string) {
this.id = id;
this.status = status;
this.paiJobName = paiJobName;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
- this.sequenceId = sequenceId;
this.tags = [];
this.hdfsLogPath = hdfsLogPath;
}
diff --git a/src/nni_manager/training_service/pai/paiTrainingService.ts b/src/nni_manager/training_service/pai/paiTrainingService.ts
index ff742f0fc0..4e949e7708 100644
--- a/src/nni_manager/training_service/pai/paiTrainingService.ts
+++ b/src/nni_manager/training_service/pai/paiTrainingService.ts
@@ -30,10 +30,10 @@ import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import { MethodNotImplementedError } from '../../common/errors';
-import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
+import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
- HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
+ HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, generateParamFileName,
@@ -70,7 +70,6 @@ class PAITrainingService implements TrainingService {
private readonly paiTokenUpdateInterval: number;
private readonly experimentId! : string;
private readonly paiJobCollector : PAIJobInfoCollector;
- private nextTrialSequenceId: number;
private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
private copyExpCodeDirPromise?: Promise<void>;
@@ -90,7 +89,6 @@ class PAITrainingService implements TrainingService {
this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
this.experimentId = getExperimentId();
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
- this.nextTrialSequenceId = -1;
this.paiTokenUpdateInterval = 7200000; //2hours
this.logCollection = 'none';
this.log.info('Construct OpenPAI training service.');
@@ -112,9 +110,7 @@ class PAITrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) {
- if (value.form.jobType === 'TRIAL') {
- jobs.push(await this.getTrialJob(key));
- }
+ jobs.push(await this.getTrialJob(key));
}
return Promise.resolve(jobs);
@@ -142,7 +138,7 @@ class PAITrainingService implements TrainingService {
this.metricsEmitter.off('metric', listener);
}
- public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`);
}
@@ -151,7 +147,6 @@ class PAITrainingService implements TrainingService {
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
- const trialSequenceId: number = this.generateSequenceId();
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
@@ -171,7 +166,6 @@ class PAITrainingService implements TrainingService {
Date.now(),
trialWorkingFolder,
form,
- trialSequenceId,
hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
@@ -181,16 +175,12 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
- public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
- if (form.jobType === 'TRIAL') {
- await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
- } else {
- throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
- }
+ await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
@@ -397,11 +387,10 @@ class PAITrainingService implements TrainingService {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders
- const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>trialJobDetail.form);
- if (trialForm !== undefined) {
+ if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
- path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
- trialForm.hyperParameters.value, { encoding: 'utf8' }
+ path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
+ trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
}
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
@@ -416,7 +405,7 @@ class PAITrainingService implements TrainingService {
`$PWD/${trialJobId}/nnioutput`,
trialJobId,
this.experimentId,
- trialJobDetail.sequenceId,
+ trialJobDetail.form.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
@@ -507,14 +496,6 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
- private generateSequenceId(): number {
- if (this.nextTrialSequenceId === -1) {
- this.nextTrialSequenceId = getInitTrialSequenceId();
- }
-
- return this.nextTrialSequenceId++;
- }
-
private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
try {
diff --git a/src/nni_manager/training_service/remote_machine/remoteMachineData.ts b/src/nni_manager/training_service/remote_machine/remoteMachineData.ts
index 02ad95d627..7c2ba2f0e5 100644
--- a/src/nni_manager/training_service/remote_machine/remoteMachineData.ts
+++ b/src/nni_manager/training_service/remote_machine/remoteMachineData.ts
@@ -22,7 +22,7 @@
import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
-import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
+import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData';
/**
@@ -82,20 +82,18 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
- public form: JobApplicationForm;
- public sequenceId: number;
+ public form: TrialJobApplicationForm;
public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean;
public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number,
- workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
+ workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
- this.sequenceId = sequenceId;
this.tags = [];
this.gpuIndices = [];
}
diff --git a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
index 35631f1ce9..4733df6809 100644
--- a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
+++ b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
@@ -30,11 +30,11 @@ import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
import { NNIError, NNIErrorNames } from '../../common/errors';
-import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
+import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
import {
- HostJobApplicationForm, HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
+ HyperParameters, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import {
@@ -172,9 +172,7 @@ class RemoteMachineTrainingService implements TrainingService {
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) {
- if (value.form.jobType === 'TRIAL') {
- jobs.push(await this.getTrialJob(key));
- }
+ jobs.push(await this.getTrialJob(key));
}
deferred.resolve(jobs);
@@ -228,33 +226,26 @@ class RemoteMachineTrainingService implements TrainingService {
* @param form trial job description form
*/
// tslint:disable-next-line:informative-docs
- public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
}
- if (form.jobType === 'HOST') {
- return this.runHostJob(form);
- } else if (form.jobType === 'TRIAL') {
- // Generate trial job id(random)
- const trialJobId: string = uniqueString(5);
- const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
+ // Generate trial job id(random)
+ const trialJobId: string = uniqueString(5);
+ const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
- const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
- trialJobId,
- 'WAITING',
- Date.now(),
- trialWorkingFolder,
- form,
- this.generateSequenceId()
- );
- this.jobQueue.push(trialJobId);
- this.trialJobsMap.set(trialJobId, trialJobDetail);
+ const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
+ trialJobId,
+ 'WAITING',
+ Date.now(),
+ trialWorkingFolder,
+ form
+ );
+ this.jobQueue.push(trialJobId);
+ this.trialJobsMap.set(trialJobId, trialJobDetail);
- return Promise.resolve(trialJobDetail);
- } else {
- return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}, jobType should be HOST or TRIAL.`));
- }
+ return Promise.resolve(trialJobDetail);
}
/**
@@ -262,20 +253,16 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
- public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
+ public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
- if (form.jobType === 'TRIAL') {
- const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
- if (rmMeta !== undefined) {
- await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, rmMeta);
- } else {
- throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
- }
+ const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
+ if (rmMeta !== undefined) {
+ await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
} else {
- throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
+ throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
}
return trialJobDetail;
@@ -558,7 +545,7 @@ class RemoteMachineTrainingService implements TrainingService {
await this.allocateSSHClientForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine(
- trialJobId, trialWorkingFolder, <TrialJobApplicationForm>trialJobDetail.form, rmScheduleInfo);
+ trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`;
@@ -628,7 +615,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialWorkingFolder,
trialJobId,
getExperimentId(),
- trialJobDetail.sequenceId.toString(),
+ trialJobDetail.form.sequenceId.toString(),
this.isMultiPhase,
unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'),
command,
@@ -657,38 +644,6 @@ class RemoteMachineTrainingService implements TrainingService {
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
}
- private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
- const rmMeta: RemoteMachineMeta = this.getRmMetaByHost(form.host);
- const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(rmMeta);
- if (sshClientManager === undefined) {
- throw new Error('sshClient not found.');
- }
- const sshClient: Client = sshClientManager.getFirstSSHClient();
- const jobId: string = uniqueString(5);
- const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId);
- const remoteDir: string = this.getHostJobRemoteDir(jobId);
- await cpp.exec(`mkdir -p ${localDir}`);
- await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteDir}`, sshClient);
- const runScriptContent: string = String.Format(
- HOST_JOB_SHELL_FORMAT, remoteDir, path.join(remoteDir, 'jobpid'), form.cmd, path.join(remoteDir, 'code')
- );
- await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' });
- await SSHClientUtility.copyFileToRemote(
- path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient);
- // tslint:disable-next-line: no-floating-promises
- SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient);
-
- const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
- jobId, 'RUNNING', Date.now(), remoteDir, form, this.generateSequenceId()
- );
- jobDetail.rmMeta = rmMeta;
- jobDetail.startTime = Date.now();
- this.trialJobsMap.set(jobId, jobDetail);
- this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
-
- return jobDetail;
- }
-
private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) {
if (rmMeta.ip === host) {
@@ -765,13 +720,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
let jobpidPath: string;
- if (trialJobDetail.form.jobType === 'TRIAL') {
- jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
- } else if (trialJobDetail.form.jobType === 'HOST') {
- jobpidPath = unixPathJoin(this.getHostJobRemoteDir(jobId), 'jobpid');
- } else {
- throw new Error(`Job type not supported: ${trialJobDetail.form.jobType}`);
- }
+ jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
return jobpidPath;
}
@@ -791,14 +740,6 @@ class RemoteMachineTrainingService implements TrainingService {
await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient);
}
-
- private generateSequenceId(): number {
- if (this.trialSequenceId === -1) {
- this.trialSequenceId = getInitTrialSequenceId();
- }
-
- return this.trialSequenceId++;
- }
}
export { RemoteMachineTrainingService };
diff --git a/src/nni_manager/training_service/test/localTrainingService.test.ts b/src/nni_manager/training_service/test/localTrainingService.test.ts
index d0d6daf3ad..2d95bb80cc 100644
--- a/src/nni_manager/training_service/test/localTrainingService.test.ts
+++ b/src/nni_manager/training_service/test/localTrainingService.test.ts
@@ -76,7 +76,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
- jobType: 'TRIAL',
+ sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
@@ -95,7 +95,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
- jobType: 'TRIAL',
+ sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
@@ -121,4 +121,4 @@ describe('Unit Test for LocalTrainingService', () => {
it('Test multiphaseSupported', () => {
chai.expect(localTrainingService.isMultiPhaseJobSupported).to.be.equals(true)
})
-});
\ No newline at end of file
+});
diff --git a/src/nni_manager/training_service/test/paiTrainingService.test.ts b/src/nni_manager/training_service/test/paiTrainingService.test.ts
index e55fe9e483..2a52362d42 100644
--- a/src/nni_manager/training_service/test/paiTrainingService.test.ts
+++ b/src/nni_manager/training_service/test/paiTrainingService.test.ts
@@ -24,6 +24,7 @@ import * as chaiAsPromised from 'chai-as-promised';
import * as fs from 'fs';
import * as tmp from 'tmp';
import * as component from '../../common/component';
+import { TrialJobApplicationForm } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAITrainingService } from '../pai/paiTrainingService';
@@ -84,12 +85,16 @@ describe('Unit Test for PAITrainingService', () => {
console.log(`paiCluster is ${paiCluster}`)
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster);
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
+ const form: TrialJobApplicationForm = {
+ sequenceId: 0,
+ hyperParameters: { value: '', index: 0 }
+ };
try {
- const trialDetail = await paiTrainingService.submitTrialJob({jobType : 'TRIAL'});
+ const trialDetail = await paiTrainingService.submitTrialJob(form);
chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) {
console.log('Submit job failed:' + error);
chai.assert(error)
}
});
-});
\ No newline at end of file
+});
diff --git a/src/nni_manager/training_service/test/remoteMachineTrainingService.test.ts b/src/nni_manager/training_service/test/remoteMachineTrainingService.test.ts
index 7509ea2ade..f8f5025e49 100644
--- a/src/nni_manager/training_service/test/remoteMachineTrainingService.test.ts
+++ b/src/nni_manager/training_service/test/remoteMachineTrainingService.test.ts
@@ -99,11 +99,11 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
await remoteMachineTrainingService.setClusterMetadata(
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
- jobType: 'TRIAL',
- hyperParameters: {
- value: 'mock hyperparameters',
- index: 0
- }
+ sequenceId: 0,
+ hyperParameters: {
+ value: 'mock hyperparameters',
+ index: 0
+ }
};
const trialJob = await remoteMachineTrainingService.submitTrialJob(form);
@@ -137,7 +137,7 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
- jobType: 'TRIAL',
+ sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
diff --git a/src/sdk/pynni/nni/constants.py b/src/sdk/pynni/nni/constants.py
index ab726baa1b..5fc515da7b 100644
--- a/src/sdk/pynni/nni/constants.py
+++ b/src/sdk/pynni/nni/constants.py
@@ -30,7 +30,8 @@
'NetworkMorphism': 'nni.networkmorphism_tuner.networkmorphism_tuner',
'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
'MetisTuner': 'nni.metis_tuner.metis_tuner',
- 'GPTuner': 'nni.gp_tuner.gp_tuner'
+ 'GPTuner': 'nni.gp_tuner.gp_tuner',
+ 'PPOTuner': 'nni.ppo_tuner.ppo_tuner'
}
ClassName = {
@@ -44,6 +45,7 @@
'NetworkMorphism':'NetworkMorphismTuner',
'MetisTuner':'MetisTuner',
'GPTuner':'GPTuner',
+ 'PPOTuner': 'PPOTuner',
'Medianstop': 'MedianstopAssessor',
'Curvefitting': 'CurvefittingAssessor'
diff --git a/src/sdk/pynni/nni/gp_tuner/target_space.py b/src/sdk/pynni/nni/gp_tuner/target_space.py
index 831bc335df..56481fa691 100644
--- a/src/sdk/pynni/nni/gp_tuner/target_space.py
+++ b/src/sdk/pynni/nni/gp_tuner/target_space.py
@@ -55,6 +55,14 @@ def __init__(self, pbounds, random_state=None):
[item[1] for item in sorted(pbounds.items(), key=lambda x: x[0])]
)
+ # check values type
+ for _bound in self._bounds:
+ if _bound['_type'] == 'choice':
+ try:
+ [float(val) for val in _bound['_value']]
+ except ValueError:
+ raise ValueError("GP Tuner supports only numerical values")
+
# preallocated memory for X and Y points
self._params = np.empty(shape=(0, self.dim))
self._target = np.empty(shape=(0))
diff --git a/src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py b/src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
index 7d1e6f7caa..9c54e03df8 100644
--- a/src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+++ b/src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
@@ -27,6 +27,7 @@
import hyperopt as hp
import numpy as np
from nni.tuner import Tuner
+from nni.nas_utils import rewrite_nas_space
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
logger = logging.getLogger('hyperopt_AutoML')
@@ -240,6 +241,7 @@ def _choose_tuner(self, algorithm_name):
return hp.anneal.suggest
raise RuntimeError('Not support tuner algorithm in hyperopt.')
+ @rewrite_nas_space
def update_search_space(self, search_space):
"""
Update search space definition in tuner by search_space in parameters.
diff --git a/src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py b/src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
index 46aeadf1ea..67658a6f60 100644
--- a/src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
+++ b/src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
@@ -79,7 +79,7 @@ def trial_end(self, trial_job_id, success):
self.completed_avg_history[trial_job_id].append(history_sum / cnt)
self.running_history.pop(trial_job_id)
else:
- logger.warning('trial_end: trial_job_id does not in running_history')
+ logger.warning('trial_end: trial_job_id does not exist in running_history')
def assess_trial(self, trial_job_id, trial_history):
"""assess_trial
@@ -112,7 +112,7 @@ def assess_trial(self, trial_job_id, trial_history):
logger.exception(error)
except Exception as error:
logger.warning('unrecognized exception in medianstop_assessor:')
- logger.excpetion(error)
+ logger.exception(error)
self._update_data(trial_job_id, num_trial_history)
if self.high_better:
diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py
index 0beda3f154..1467b27695 100644
--- a/src/sdk/pynni/nni/msg_dispatcher.py
+++ b/src/sdk/pynni/nni/msg_dispatcher.py
@@ -22,6 +22,7 @@
from collections import defaultdict
import json_tricks
+from nni import NoMoreTrialError
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
@@ -100,11 +101,16 @@ def handle_initialize(self, data):
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
+ def send_trial_callback(self, id, params):
+ """For tuner to issue trial config when the config is generated
+ """
+ send(CommandType.NewTrialJob, _pack_parameter(id, params))
+
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
_logger.debug("requesting for generating params of {}".format(ids))
- params_list = self.tuner.generate_multiple_parameters(ids)
+ params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)
for i, _ in enumerate(params_list):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
@@ -144,7 +150,10 @@ def handle_report_metric_data(self, data):
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
- param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
+ try:
+ param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
+ except NoMoreTrialError:
+ param = None
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
@@ -171,15 +180,15 @@ def _handle_final_metric_data(self, data):
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
- if multi_phase_enabled():
- self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
- else:
- self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
+ if not hasattr(self.tuner, '_accept_customized'):
+ self.tuner._accept_customized = False
+ if not self.tuner._accept_customized:
+ _logger.info('Customized trial job %s ignored by tuner', id_)
+ return
+ customized = True
else:
- if multi_phase_enabled():
- self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
- else:
- self.tuner.receive_trial_result(id_, _trial_params[id_], value)
+ customized = False
+ self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized, trial_job_id=data.get('trial_job_id'))
def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results
diff --git a/src/sdk/pynni/nni/nas_utils.py b/src/sdk/pynni/nni/nas_utils.py
index efc25194a4..70e66b318e 100644
--- a/src/sdk/pynni/nni/nas_utils.py
+++ b/src/sdk/pynni/nni/nas_utils.py
@@ -17,10 +17,16 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
+import functools
+import logging
from . import trial
+_logger = logging.getLogger(__name__)
+_MUTABLE_LAYER_SPACE_PREFIX = "_mutable_layer"
+
+
def classic_mode(
mutable_id,
mutable_layer_id,
@@ -34,13 +40,11 @@ def classic_mode(
without touching the full model graph.'''
if trial.get_current_parameter() is None:
trial.get_next_parameter()
- mutable_block = trial.get_current_parameter(mutable_id)
- chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
- chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
- real_chosen_inputs = [optional_inputs[input_name]
- for input_name in chosen_inputs]
- layer_out = funcs[chosen_layer](
- [fixed_inputs, real_chosen_inputs], **funcs_args[chosen_layer])
+
+ chosen_layer, chosen_inputs = _get_layer_and_inputs_from_tuner(mutable_id, mutable_layer_id,
+ list(optional_inputs.keys()))
+ real_chosen_inputs = [optional_inputs[input_name] for input_name in chosen_inputs]
+ layer_out = funcs[chosen_layer]([fixed_inputs, real_chosen_inputs], **funcs_args[chosen_layer])
return layer_out
@@ -173,20 +177,44 @@ def reload_tensorflow_variables(tf, session):
tf: tensorflow module
'''
subgraph_from_tuner = trial.get_next_parameter()
- for mutable_id, mutable_block in subgraph_from_tuner.items():
+ mutable_layers = set()
+ for subgraph_key in subgraph_from_tuner:
+ if "/" in subgraph_key:
+ # has to remove the last, could be layer_choice or whatever
+ mutable_id, mutable_layer_id = _decompose_general_key(subgraph_key[:subgraph_key.rfind("/")])
+ if mutable_id is not None:
+ mutable_layers.add((mutable_id, mutable_layer_id))
+ mutable_layers = sorted(list(mutable_layers))
+ for mutable_id, mutable_layer_id in mutable_layers:
if mutable_id not in name_space:
+ _logger.warning("{} not found in name space".format(mutable_id))
continue
- for mutable_layer_id, mutable_layer in mutable_block.items():
- name_prefix = "{}_{}".format(mutable_id, mutable_layer_id)
- # extract layer information from the subgraph sampled by tuner
- chosen_layer = name_space[name_prefix]['funcs'].index(
- mutable_layer["chosen_layer"])
- chosen_inputs = [1 if inp in mutable_layer["chosen_inputs"]
- else 0 for inp in name_space[name_prefix]['optional_inputs']]
- # load these information into pre-defined tensorflow variables
- tf_variables[name_prefix]['funcs'].load(chosen_layer, session)
- tf_variables[name_prefix]['optional_inputs'].load(
- chosen_inputs, session)
+ name_prefix = "{}_{}".format(mutable_id, mutable_layer_id)
+ # get optional inputs names
+ optional_inputs = name_space[name_prefix]['optional_inputs']
+ # extract layer information from the subgraph sampled by tuner
+ chosen_layer, chosen_inputs = _get_layer_and_inputs_from_tuner(mutable_id, mutable_layer_id, optional_inputs)
+ chosen_layer = name_space[name_prefix]['funcs'].index(chosen_layer)
+ chosen_inputs = [1 if inp in chosen_inputs else 0 for inp in optional_inputs]
+ # load these information into pre-defined tensorflow variables
+ tf_variables[name_prefix]['funcs'].load(chosen_layer, session)
+ tf_variables[name_prefix]['optional_inputs'].load(
+ chosen_inputs, session)
+
+
+def _construct_general_key(mutable_id, mutable_layer_id):
+ # Mutable layer key in a general (search space) format
+ # that is, prefix/mutable_id/mutable_layer_id
+ return _MUTABLE_LAYER_SPACE_PREFIX + "/" + mutable_id + "/" + mutable_layer_id
+
+
+def _decompose_general_key(key):
+ # inverse operation of above
+ if not key.startswith(_MUTABLE_LAYER_SPACE_PREFIX):
+ return None, None
+ else:
+ _, mutable_id, mutable_layer_id = key.split("/", maxsplit=2)
+ return mutable_id, mutable_layer_id
def darts_training(tf, session, loss, feed_dict):
@@ -205,4 +233,107 @@ def training_update(nas_mode, tf=None, session=None, loss=None, feed_dict=None):
if nas_mode == 'darts_mode':
darts_training(tf, session, loss, feed_dict)
elif nas_mode == 'enas_mode':
- reload_tensorflow_variables(tf, session)
\ No newline at end of file
+ reload_tensorflow_variables(tf, session)
+
+
+def _get_layer_and_inputs_from_tuner(mutable_id, mutable_layer_id, optional_inputs):
+ # optional_inputs should be name(key)s of the optional inputs
+ try:
+ mutable_block = trial.get_current_parameter(mutable_id)
+
+ # There is a NAS tuner
+ chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
+ chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
+ except KeyError:
+ # Try to find converted NAS parameters
+ params = trial.get_current_parameter()
+ expected_prefix = _construct_general_key(mutable_id, mutable_layer_id)
+ chosen_layer = params[expected_prefix + "/layer_choice"]
+
+ # find how many to choose
+ optional_input_size = int(params[expected_prefix + "/optional_input_size"]) # convert uniform to randint
+
+ # find who to choose, can duplicate
+ optional_input_state = params[expected_prefix + "/optional_input_chosen_state"]
+ chosen_inputs = []
+ # make sure dict -> list produce stable result by sorting
+ optional_inputs_keys = sorted(optional_inputs)
+ for i in range(optional_input_size):
+ chosen_inputs.append(optional_inputs_keys[optional_input_state % len(optional_inputs)])
+ optional_input_state //= len(optional_inputs)
+
+ _logger.info("%s_%s: layer: %s, optional inputs: %s" % (mutable_id, mutable_layer_id,
+ chosen_layer, chosen_inputs))
+ return chosen_layer, chosen_inputs
+
+
+def convert_nas_search_space(search_space):
+ """
+ Args:
+ search_space: raw search space
+ Returns: the new search space, with each mutable_layer converted into choice/randint entries
+ """
+ ret = dict()
+ for k, v in search_space.items():
+ if "_type" not in v:
+ # this should not happen
+ _logger.warning("There is no _type in one of your search space values with key '%s'"
+ ". Please check your search space" % k)
+ ret[k] = v
+ elif v["_type"] != "mutable_layer":
+ ret[k] = v
+ else:
+ _logger.info("Converting mutable_layer search space with key '%s'" % k)
+ # v["_value"] looks like {'mutable_layer_1': {'layer_choice': ...} ...}
+ values = v["_value"]
+ for layer_name, layer_data in values.items():
+ # there should be at most layer_choice, optional_inputs, optional_input_size in layer_data
+
+ # add "_mutable_layer" as prefix so that they can be recovered later
+ layer_key = _construct_general_key(k, layer_name)
+
+ if layer_data.get("layer_choice"): # filter out empty choice and no choice
+ layer_choice = layer_data["layer_choice"]
+ else:
+ raise ValueError("No layer choice found in %s" % layer_key)
+
+ if layer_data.get("optional_input_size"):
+ input_size = layer_data["optional_input_size"]
+ if isinstance(input_size, int):
+ input_size = [input_size, input_size]
+ if input_size[0] > input_size[1] or input_size[0] < 0:
+ _logger.error("Might not be able to handle optional_input_size with min > max or min < 0, please double check")
+ input_size[1] += 1
+ else:
+ _logger.info("Optional input choices are set to empty by default in %s" % layer_key)
+ input_size = [0, 1]
+
+ if layer_data.get("optional_inputs"):
+ total_state_size = len(layer_data["optional_inputs"]) ** (input_size[1] - 1)
+ else:
+ _logger.info("Optional inputs not found in %s" % layer_key)
+ total_state_size = 1
+
+ converted = {
+ layer_key + "/layer_choice": {
+ "_type": "choice", "_value": layer_choice
+ },
+ layer_key + "/optional_input_size": {
+ "_type": "randint", "_value": input_size
+ },
+ layer_key + "/optional_input_chosen_state": {
+ "_type": "randint", "_value": [0, total_state_size]
+ }
+ }
+ _logger.info(converted)
+ ret.update(converted)
+
+ return ret
+
+
+def rewrite_nas_space(func):
+ @functools.wraps(func)
+ def wrap(self, search_space):
+ search_space = convert_nas_search_space(search_space)
+ return func(self, search_space)
+ return wrap
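Note on the `optional_input_chosen_state` encoding used above: the integer state appears to be read as a base-n number (n = number of optional inputs), with one digit per chosen input. The sketch below mirrors that loop in plain Python so the mapping can be checked in isolation. The function name `decode_chosen_inputs` and the sample input names are hypothetical and not part of this patch.

```python
def decode_chosen_inputs(optional_inputs, size, state):
    """Read `state` as a base-n number; each digit selects one input.

    Hypothetical helper mirroring the loop in _get_layer_and_inputs_from_tuner.
    Duplicates are possible because digits are independent.
    """
    keys = sorted(optional_inputs)  # sorting keeps the mapping stable
    n = len(keys)
    chosen = []
    for _ in range(size):
        chosen.append(keys[state % n])  # lowest digit first
        state //= n
    return chosen


# Illustrative input names only.
inputs = ["conv", "pool", "skip"]
example = decode_chosen_inputs(inputs, size=2, state=5)
```

With n inputs and k slots there are n ** k distinct states, which matches the `total_state_size` computed in `convert_nas_search_space` for the maximum input size.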
diff --git a/src/sdk/pynni/nni/ppo_tuner/__init__.py b/src/sdk/pynni/nni/ppo_tuner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sdk/pynni/nni/ppo_tuner/distri.py b/src/sdk/pynni/nni/ppo_tuner/distri.py
new file mode 100644
index 0000000000..4666acc2da
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/distri.py
@@ -0,0 +1,198 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+functions for sampling from hidden state
+"""
+
+import tensorflow as tf
+
+from .util import fc
+
+
+class Pd:
+ """
+ A particular probability distribution
+ """
+ def flatparam(self):
+ raise NotImplementedError
+ def mode(self):
+ raise NotImplementedError
+ def neglogp(self, x):
+ # Usually it's easier to define the negative logprob
+ raise NotImplementedError
+ def kl(self, other):
+ raise NotImplementedError
+ def entropy(self):
+ raise NotImplementedError
+ def sample(self):
+ raise NotImplementedError
+ def logp(self, x):
+ return - self.neglogp(x)
+ def get_shape(self):
+ return self.flatparam().shape
+ @property
+ def shape(self):
+ return self.get_shape()
+ def __getitem__(self, idx):
+ return self.__class__(self.flatparam()[idx])
+
+class PdType:
+ """
+ Parametrized family of probability distributions
+ """
+ def pdclass(self):
+ raise NotImplementedError
+ def pdfromflat(self, flat, mask, nsteps, size, is_act_model):
+ return self.pdclass()(flat, mask, nsteps, size, is_act_model)
+ def pdfromlatent(self, latent_vector, init_scale, init_bias):
+ raise NotImplementedError
+ def param_shape(self):
+ raise NotImplementedError
+ def sample_shape(self):
+ raise NotImplementedError
+ def sample_dtype(self):
+ raise NotImplementedError
+
+ def param_placeholder(self, prepend_shape, name=None):
+ return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
+ def sample_placeholder(self, prepend_shape, name=None):
+ return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)
+
+class CategoricalPd(Pd):
+ """
+ categorical probability distribution
+ """
+ def __init__(self, logits, mask_npinf, nsteps, size, is_act_model):
+ self.logits = logits
+ self.mask_npinf = mask_npinf
+ self.nsteps = nsteps
+ self.size = size
+ self.is_act_model = is_act_model
+ def flatparam(self):
+ return self.logits
+ def mode(self):
+ return tf.argmax(self.logits, axis=-1)
+
+ @property
+ def mean(self):
+ return tf.nn.softmax(self.logits)
+ def neglogp(self, x):
+ """
+ Equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x).
+ Note: we can't use sparse_softmax_cross_entropy_with_logits directly because
+ the implementation does not allow second-order derivatives...
+ """
+ if x.dtype in {tf.uint8, tf.int32, tf.int64}:
+ # one-hot encoding
+ x_shape_list = x.shape.as_list()
+ logits_shape_list = self.logits.get_shape().as_list()[:-1]
+ for xs, ls in zip(x_shape_list, logits_shape_list):
+ if xs is not None and ls is not None:
+ assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)
+
+ x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
+ else:
+ # already encoded
+ assert x.shape.as_list() == self.logits.shape.as_list()
+
+ return tf.nn.softmax_cross_entropy_with_logits_v2(
+ logits=self.logits,
+ labels=x)
+
+ def kl(self, other):
+ """compute the KL divergence KL(self || other)"""
+ a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
+ a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
+ ea0 = tf.exp(a0)
+ ea1 = tf.exp(a1)
+ z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
+ z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
+ p0 = ea0 / z0
+ return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
+
+ def entropy(self):
+ """compute entropy"""
+ a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
+ ea0 = tf.exp(a0)
+ z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
+ p0 = ea0 / z0
+ return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
+
+ def sample(self):
+ """sample from logits"""
+ if not self.is_act_model:
+ re_res = tf.reshape(self.logits, [-1, self.nsteps, self.size])
+ masked_res = tf.math.add(re_res, self.mask_npinf)
+ re_masked_res = tf.reshape(masked_res, [-1, self.size])
+
+ u = tf.random_uniform(tf.shape(re_masked_res), dtype=self.logits.dtype)
+ return tf.argmax(re_masked_res - tf.log(-tf.log(u)), axis=-1)
+ else:
+ u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
+ return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
+
+ @classmethod
+ def fromflat(cls, flat):
+ return cls(flat)
+
+class CategoricalPdType(PdType):
+ """
+ to create CategoricalPd
+ """
+ def __init__(self, ncat, nsteps, np_mask, is_act_model):
+ self.ncat = ncat
+ self.nsteps = nsteps
+ self.np_mask = np_mask
+ self.is_act_model = is_act_model
+ def pdclass(self):
+ return CategoricalPd
+
+ def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
+ """add fc and create CategoricalPd"""
+ pdparam, mask, mask_npinf = _matching_fc(latent_vector, 'pi', self.ncat, self.nsteps,
+ init_scale=init_scale, init_bias=init_bias,
+ np_mask=self.np_mask, is_act_model=self.is_act_model)
+ return self.pdfromflat(pdparam, mask_npinf, self.nsteps, self.ncat, self.is_act_model), pdparam, mask, mask_npinf
+
+ def param_shape(self):
+ return [self.ncat]
+ def sample_shape(self):
+ return []
+ def sample_dtype(self):
+ return tf.int32
+
+def _matching_fc(tensor, name, size, nsteps, init_scale, init_bias, np_mask, is_act_model):
+ """
+ add fc op, and add mask op when not in action mode
+ """
+ if tensor.shape[-1] == size:
+ assert False
+ return tensor
+ else:
+ mask = tf.get_variable("act_mask", dtype=tf.float32, initializer=np_mask[0], trainable=False)
+ mask_npinf = tf.get_variable("act_mask_npinf", dtype=tf.float32, initializer=np_mask[1], trainable=False)
+ res = fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
+ if not is_act_model:
+ re_res = tf.reshape(res, [-1, nsteps, size])
+ masked_res = tf.math.multiply(re_res, mask)
+ re_masked_res = tf.reshape(masked_res, [-1, size])
+ return re_masked_res, mask, mask_npinf
+ else:
+ return res, mask, mask_npinf
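Note on `CategoricalPd.sample`: the expression `argmax(logits - log(-log(u)))` is the Gumbel-max trick, and adding a mask of `-inf` entries before sampling rules out disallowed actions. The following is a minimal pure-Python sketch of the same idea, assuming a single 1-D logits vector rather than the batched tensors used in the patch. The function name `gumbel_max_sample` and the example values are illustrative only.

```python
import math
import random


def gumbel_max_sample(logits, mask_npinf, rng=random):
    """Draw an index distributed as softmax(logits + mask_npinf).

    Adds independent Gumbel(0, 1) noise to each masked logit and takes the
    argmax. Entries whose mask is -inf can never win.
    """
    best_idx, best_val = -1, -math.inf
    for i, (logit, m) in enumerate(zip(logits, mask_npinf)):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        gumbel = -math.log(-math.log(u))
        val = logit + m + gumbel
        if val > best_val:
            best_idx, best_val = i, val
    return best_idx


rng = random.Random(0)
logits = [0.0, 1.0, 2.0]
mask_npinf = [0.0, -math.inf, 0.0]  # index 1 is not allowed at this step
samples = [gumbel_max_sample(logits, mask_npinf, rng) for _ in range(1000)]
```

Index 1 never appears in `samples`, and index 2 is drawn more often than index 0 because its logit is larger.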
diff --git a/src/sdk/pynni/nni/ppo_tuner/model.py b/src/sdk/pynni/nni/ppo_tuner/model.py
new file mode 100644
index 0000000000..330f10369d
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/model.py
@@ -0,0 +1,166 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+the main model of policy/value network
+"""
+
+import tensorflow as tf
+
+from .util import initialize, get_session
+
+class Model:
+ """
+ We use this object to:
+ __init__:
+ - Creates the step_model
+ - Creates the train_model
+
+ train():
+ - Run the training step (feedforward and backpropagation of gradients)
+
+ save/load():
+ - Save/load the model
+ """
+ def __init__(self, *, policy, nbatch_act, nbatch_train,
+ nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None, np_mask=None):
+ """
+ init
+ """
+ self.sess = sess = get_session()
+
+ with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
+ # CREATE OUR TWO MODELS
+ # act_model that is used for sampling
+ act_model = policy(nbatch_act, 1, sess, np_mask=np_mask, is_act_model=True)
+
+ # Train model for training
+ if microbatch_size is None:
+ train_model = policy(nbatch_train, nsteps, sess, np_mask=np_mask, is_act_model=False)
+ else:
+ train_model = policy(microbatch_size, nsteps, sess, np_mask=np_mask, is_act_model=False)
+
+ # CREATE THE PLACEHOLDERS
+ self.A = A = train_model.pdtype.sample_placeholder([None])
+ self.ADV = ADV = tf.placeholder(tf.float32, [None])
+ self.R = R = tf.placeholder(tf.float32, [None])
+ # Keep track of old actor
+ self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
+ # Keep track of old critic
+ self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None])
+ self.LR = LR = tf.placeholder(tf.float32, [])
+ # Cliprange
+ self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [])
+
+ neglogpac = train_model.pd.neglogp(A)
+
+ # Calculate the entropy
+ # Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
+ entropy = tf.reduce_mean(train_model.pd.entropy())
+
+ # CALCULATE THE LOSS
+ # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
+
+ # Clip the value to reduce variability during Critic training
+ # Get the predicted value
+ vpred = train_model.vf
+ vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
+ # Unclipped value
+ vf_losses1 = tf.square(vpred - R)
+ # Clipped value
+ vf_losses2 = tf.square(vpredclipped - R)
+
+ vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
+
+ # Calculate ratio (pi current policy / pi old policy)
+ ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
+
+ # Minimizing Loss = -J is equivalent to maximizing J
+ pg_losses = -ADV * ratio
+
+ pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
+
+ # Final PG loss
+ pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
+ approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
+ clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
+
+ # Total loss
+ loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
+
+ # UPDATE THE PARAMETERS USING LOSS
+ # 1. Get the model parameters
+ params = tf.trainable_variables('ppo2_model')
+ # 2. Build our trainer
+ self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
+ # 3. Calculate the gradients
+ grads_and_var = self.trainer.compute_gradients(loss, params)
+ grads, var = zip(*grads_and_var)
+
+ if max_grad_norm is not None:
+ # Clip the gradients (normalize)
+ grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
+ grads_and_var = list(zip(grads, var))
+ # zip pairs each gradient with its associated parameter
+ # For instance zip(ABCD, xyza) => Ax, By, Cz, Da
+
+ self.grads = grads
+ self.var = var
+ self._train_op = self.trainer.apply_gradients(grads_and_var)
+ self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
+ self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac]
+
+
+ self.train_model = train_model
+ self.act_model = act_model
+ self.step = act_model.step
+ self.value = act_model.value
+ self.initial_state = act_model.initial_state
+
+ initialize()
+
+ def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
+ """
+ train the model.
+ Here we calculate advantage A(s,a) = R + gamma * V(s') - V(s),
+ where returns = R + gamma * V(s')
+ """
+ advs = returns - values
+
+ # Normalize the advantages
+ advs = (advs - advs.mean()) / (advs.std() + 1e-8)
+
+ td_map = {
+ self.train_model.X : obs,
+ self.A : actions,
+ self.ADV : advs,
+ self.R : returns,
+ self.LR : lr,
+ self.CLIPRANGE : cliprange,
+ self.OLDNEGLOGPAC : neglogpacs,
+ self.OLDVPRED : values
+ }
+ if states is not None:
+ td_map[self.train_model.S] = states
+ td_map[self.train_model.M] = masks
+
+ return self.sess.run(
+ self.stats_list + [self._train_op],
+ td_map
+ )[:-1]
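Note on the policy loss in `Model.__init__`: `pg_loss` takes the element-wise maximum of the unclipped and clipped negated surrogate terms, which is the usual PPO clipped objective written as a loss. The sketch below evaluates that expression for a single sample in plain Python, assuming scalar inputs instead of tensors. The function name `ppo_policy_loss` and the numbers are illustrative, not taken from the patch.

```python
import math


def ppo_policy_loss(adv, old_neglogp, neglogp, cliprange):
    """Per-sample PPO clipped policy loss (scalar version of pg_loss)."""
    # ratio = pi_new(a|s) / pi_old(a|s), computed from negative log-probs
    ratio = math.exp(old_neglogp - neglogp)
    unclipped = -adv * ratio
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    clipped = -adv * clipped_ratio
    # the maximum of the two losses is the pessimistic (lower) bound on J
    return max(unclipped, clipped)


# Positive advantage with ratio exp(0.5) > 1.2: the clipped term applies.
loss_pos = ppo_policy_loss(adv=1.0, old_neglogp=1.0, neglogp=0.5, cliprange=0.2)
# Negative advantage with the same ratio: the unclipped term applies.
loss_neg = ppo_policy_loss(adv=-1.0, old_neglogp=1.0, neglogp=0.5, cliprange=0.2)
```

When the ratio moves outside `[1 - cliprange, 1 + cliprange]` in the direction that would increase the objective, the clipped term is selected and its gradient with respect to the policy is zero, so the update stops pushing further.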
diff --git a/src/sdk/pynni/nni/ppo_tuner/policy.py b/src/sdk/pynni/nni/ppo_tuner/policy.py
new file mode 100644
index 0000000000..65e2db414e
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/policy.py
@@ -0,0 +1,219 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+build policy/value network from model
+"""
+
+import tensorflow as tf
+
+from .distri import CategoricalPdType
+from .util import lstm_model, fc, observation_placeholder, adjust_shape
+
+
+class PolicyWithValue:
+ """
+ Encapsulates fields and methods for RL policy and value function estimation with shared parameters
+ """
+
+ def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, np_mask=None, is_act_model=False, **tensors):
+ """
+ Parameters:
+ ----------
+ env: RL environment
+ observations: tensorflow placeholder in which the observations will be fed
+ latent: latent state from which policy distribution parameters should be inferred
+ vf_latent: latent state from which value function should be inferred (if None, then latent is used)
+ sess: tensorflow session to run calculations in (if None, default session is used)
+ **tensors: tensorflow tensors for additional attributes such as state or mask
+ """
+
+ self.X = observations
+ self.state = tf.constant([])
+ self.initial_state = None
+ self.__dict__.update(tensors)
+
+ vf_latent = vf_latent if vf_latent is not None else latent
+
+ vf_latent = tf.layers.flatten(vf_latent)
+ latent = tf.layers.flatten(latent)
+
+ # Based on the action space, will select what probability distribution type
+ self.np_mask = np_mask
+ self.pdtype = CategoricalPdType(env.action_space.n, env.nsteps, np_mask, is_act_model)
+
+ self.act_latent = latent
+ self.nh = env.action_space.n
+
+ self.pd, self.pi, self.mask, self.mask_npinf = self.pdtype.pdfromlatent(latent, init_scale=0.01)
+
+ # Take an action
+ self.action = self.pd.sample()
+
+ # Calculate the neg log of our probability
+ self.neglogp = self.pd.neglogp(self.action)
+ self.sess = sess or tf.get_default_session()
+
+ assert estimate_q is False
+ self.vf = fc(vf_latent, 'vf', 1)
+ self.vf = self.vf[:, 0]
+
+ if is_act_model:
+ self._build_model_for_step()
+
+ def _evaluate(self, variables, observation, **extra_feed):
+ sess = self.sess
+ feed_dict = {self.X: adjust_shape(self.X, observation)}
+ for inpt_name, data in extra_feed.items():
+ if inpt_name in self.__dict__.keys():
+ inpt = self.__dict__[inpt_name]
+ if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder':
+ feed_dict[inpt] = adjust_shape(inpt, data)
+
+ return sess.run(variables, feed_dict)
+
+ def _build_model_for_step(self):
+ # multiply self.act_latent with the weight and apply the mask to generate the action and its neglogp for one step
+ self.act_step = step = tf.placeholder(shape=(), dtype=tf.int64, name='act_step')
+ with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
+ from .util import ortho_init
+ nin = self.act_latent.get_shape()[1].value
+ w = tf.get_variable("w", [nin, self.nh], initializer=ortho_init(0.01))
+ b = tf.get_variable("b", [self.nh], initializer=tf.constant_initializer(0.0))
+ logits = tf.matmul(self.act_latent, w)+b
+ piece = tf.slice(self.mask, [step, 0], [1, self.nh])
+ re_piece = tf.reshape(piece, [-1])
+ masked_logits = tf.math.multiply(logits, re_piece)
+
+ npinf_piece = tf.slice(self.mask_npinf, [step, 0], [1, self.nh])
+ re_npinf_piece = tf.reshape(npinf_piece, [-1])
+
+ def sample(logits, mask_npinf):
+ new_logits = tf.math.add(logits, mask_npinf)
+ u = tf.random_uniform(tf.shape(new_logits), dtype=logits.dtype)
+ return tf.argmax(new_logits - tf.log(-tf.log(u)), axis=-1)
+
+ def neglogp(logits, x):
+ # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
+ # Note: we can't use sparse_softmax_cross_entropy_with_logits because
+ # the implementation does not allow second-order derivatives...
+ if x.dtype in {tf.uint8, tf.int32, tf.int64}:
+ # one-hot encoding
+ x_shape_list = x.shape.as_list()
+ logits_shape_list = logits.get_shape().as_list()[:-1]
+ for xs, ls in zip(x_shape_list, logits_shape_list):
+ if xs is not None and ls is not None:
+ assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)
+
+ x = tf.one_hot(x, logits.get_shape().as_list()[-1])
+ else:
+ # already encoded
+ assert x.shape.as_list() == logits.shape.as_list()
+
+ return tf.nn.softmax_cross_entropy_with_logits_v2(
+ logits=logits,
+ labels=x)
+
+ self.act_action = sample(masked_logits, re_npinf_piece)
+ self.act_neglogp = neglogp(masked_logits, self.act_action)
+
+
+ def step(self, step, observation, **extra_feed):
+ """
+ Compute next action(s) given the observation(s)
+
+ Parameters:
+ ----------
+ observation: observation data (either single or a batch)
+ **extra_feed: additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)
+
+ Returns:
+ -------
+ (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple
+ """
+ extra_feed['act_step'] = step
+ a, v, state, neglogp = self._evaluate([self.act_action, self.vf, self.state, self.act_neglogp], observation, **extra_feed)
+ if state.size == 0:
+ state = None
+ return a, v, state, neglogp
+
+ def value(self, ob, *args, **kwargs):
+ """
+ Compute value estimate(s) given the observation(s)
+
+ Parameters:
+ ----------
+ ob: observation data (either single or a batch)
+ **kwargs: additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)
+
+ Returns:
+ -------
+ value estimate
+ """
+ return self._evaluate(self.vf, ob, *args, **kwargs)
+
+
+def build_lstm_policy(model_config, value_network=None, estimate_q=False, **policy_kwargs):
+ """
+ build lstm policy and value network, they share the same lstm network.
+ the parameters all use their default values.
+ """
+ policy_network = lstm_model(**policy_kwargs)
+
+ def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None, np_mask=None, is_act_model=False):
+ ob_space = model_config.observation_space
+
+ X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch)
+
+ extra_tensors = {}
+
+ # encode_observation is not necessary anymore as we use embedding_lookup
+ encoded_x = X
+
+ with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
+ policy_latent = policy_network(encoded_x, 1, model_config.observation_space.n)
+ if isinstance(policy_latent, tuple):
+ policy_latent, recurrent_tensors = policy_latent
+
+ if recurrent_tensors is not None:
+ # recurrent architecture, need a few more steps
+ nenv = nbatch // nsteps
+ assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
+ policy_latent, recurrent_tensors = policy_network(encoded_x, nenv, model_config.observation_space.n)
+ extra_tensors.update(recurrent_tensors)
+
+ _v_net = value_network
+
+ assert _v_net is None or _v_net == 'shared'
+ vf_latent = policy_latent
+
+ policy = PolicyWithValue(
+ env=model_config,
+ observations=X,
+ latent=policy_latent,
+ vf_latent=vf_latent,
+ sess=sess,
+ estimate_q=estimate_q,
+ np_mask=np_mask,
+ is_act_model=is_act_model,
+ **extra_tensors
+ )
+ return policy
+
+ return policy_fn
diff --git a/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py b/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py
new file mode 100644
index 0000000000..1bc86ae750
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/ppo_tuner.py
@@ -0,0 +1,598 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+ppo_tuner.py including:
+ class PPOTuner
+"""
+
+import os
+import copy
+import logging
+import numpy as np
+import json_tricks
+from gym import spaces
+
+import nni
+from nni.tuner import Tuner
+from nni.utils import OptimizeMode, extract_scalar_reward
+
+from .model import Model
+from .util import set_global_seeds
+from .policy import build_lstm_policy
+
+
+logger = logging.getLogger('ppo_tuner_AutoML')
+
+def constfn(val):
+ """wrap as function"""
+ def f(_):
+ return val
+ return f
+
+
+class ModelConfig:
+ """
+ Configurations of the PPO model
+ """
+ def __init__(self):
+ self.observation_space = None
+ self.action_space = None
+ self.num_envs = 0
+ self.nsteps = 0
+
+ self.ent_coef = 0.0
+ self.lr = 3e-4
+ self.vf_coef = 0.5
+ self.max_grad_norm = 0.5
+ self.gamma = 0.99
+ self.lam = 0.95
+ self.cliprange = 0.2
+ self.embedding_size = None # the embedding is for each action
+
+ self.noptepochs = 4 # number of training epochs per update
+ self.total_timesteps = 5000 # number of timesteps (i.e. number of actions taken in the environment)
+ self.nminibatches = 4 # number of training minibatches per update. For recurrent policies,
+ # should be smaller than or equal to the number of environments run in parallel.
+
+class TrialsInfo:
+ """
+ Information about each trial from one model inference
+ """
+ def __init__(self, obs, actions, values, neglogpacs, dones, last_value, inf_batch_size):
+ self.iter = 0
+ self.obs = obs
+ self.actions = actions
+ self.values = values
+ self.neglogpacs = neglogpacs
+ self.dones = dones
+ self.last_value = last_value
+
+ self.rewards = None
+ self.returns = None
+
+ self.inf_batch_size = inf_batch_size
+ #self.states = None
+
+ def get_next(self):
+ """
+ get actions of the next trial
+ """
+ if self.iter >= self.inf_batch_size:
+ return None, None
+ actions = []
+ for step in self.actions:
+ actions.append(step[self.iter])
+ self.iter += 1
+ return self.iter - 1, actions
+
+ def update_rewards(self, rewards, returns):
+ """
+ after the trial is finished, the reward and return of this trial are updated
+ """
+ self.rewards = rewards
+ self.returns = returns
+
+ def convert_shape(self):
+ """
+ convert shape
+ """
+ def sf01(arr):
+ """
+ swap and then flatten axes 0 and 1
+ """
+ s = arr.shape
+ return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
+ self.obs = sf01(self.obs)
+ self.returns = sf01(self.returns)
+ self.dones = sf01(self.dones)
+ self.actions = sf01(self.actions)
+ self.values = sf01(self.values)
+ self.neglogpacs = sf01(self.neglogpacs)
+
+
+class PPOModel:
+ """
+ PPO Model
+ """
+ def __init__(self, model_config, mask):
+ self.model_config = model_config
+ self.states = None # initial state of lstm in policy/value network
+ self.nupdates = None # the number of times func train is invoked, used to tune lr and cliprange
+ self.cur_update = 1 # record the current update
+ self.np_mask = mask # record the mask of each action within one trial
+
+ set_global_seeds(None)
+ assert isinstance(self.model_config.lr, float)
+ self.lr = constfn(self.model_config.lr)
+ assert isinstance(self.model_config.cliprange, float)
+ self.cliprange = constfn(self.model_config.cliprange)
+
+ # build lstm policy network, value share the same network
+ policy = build_lstm_policy(model_config)
+
+ # Get the number of envs
+ nenvs = model_config.num_envs
+
+ # Calculate the batch_size
+ self.nbatch = nbatch = nenvs * model_config.nsteps # num of record per update
+ nbatch_train = nbatch // model_config.nminibatches # get batch size
+ # self.nupdates is used to tune lr and cliprange
+ self.nupdates = self.model_config.total_timesteps // self.nbatch
+
+ # Instantiate the model object (that creates act_model and train_model)
+ self.model = Model(policy=policy, nbatch_act=nenvs, nbatch_train=nbatch_train,
+ nsteps=model_config.nsteps, ent_coef=model_config.ent_coef, vf_coef=model_config.vf_coef,
+ max_grad_norm=model_config.max_grad_norm, np_mask=self.np_mask)
+
+ self.states = self.model.initial_state
+
+ logger.info('=== finished PPOModel initialization')
+
+ def inference(self, num):
+ """
+ generate actions along with related info from policy network.
+ observation is the action of the last step.
+
+ Parameters:
+ ----------
+ num: the number of trials to generate
+ """
+ # Here, we init the lists that will contain the mb of experiences
+ mb_obs, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], []
+ # initial observation
+ # use the (n+1)th embedding to represent the first step action
+ first_step_ob = self.model_config.action_space.n
+ obs = [first_step_ob for _ in range(num)]
+ dones = [True for _ in range(num)]
+ states = self.states
+ # For n in range number of steps
+ for cur_step in range(self.model_config.nsteps):
+ # Given observations, get actions, values and neglogpacs
+ # obs initially holds the placeholder first-step observation, then the actions of the previous step
+ actions, values, states, neglogpacs = self.model.step(cur_step, obs, S=states, M=dones)
+ mb_obs.append(obs.copy())
+ mb_actions.append(actions)
+ mb_values.append(values)
+ mb_neglogpacs.append(neglogpacs)
+ mb_dones.append(dones)
+
+ # There is no real env to step: the actions just taken become the next observations
+ # dones is True only after the last step (and for the initial step above)
+ obs[:] = actions
+ if cur_step == self.model_config.nsteps - 1:
+ dones = [True for _ in range(num)]
+ else:
+ dones = [False for _ in range(num)]
+
+ #batch of steps to batch of rollouts
+ np_obs = np.asarray(obs)
+ mb_obs = np.asarray(mb_obs, dtype=np_obs.dtype)
+ mb_actions = np.asarray(mb_actions)
+ mb_values = np.asarray(mb_values, dtype=np.float32)
+ mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
+ mb_dones = np.asarray(mb_dones, dtype=bool)
+ last_values = self.model.value(np_obs, S=states, M=dones)
+
+ return mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values
+
+ def compute_rewards(self, trials_info, trials_result):
+ """
+ compute the rewards of the trials in trials_info based on trials_result,
+ and update the rewards in trials_info
+
+ Parameters:
+ ----------
+ trials_info: info of the generated trials
+ trials_result: final results (e.g., acc) of the generated trials
+ """
+ mb_rewards = np.asarray([trials_result for _ in trials_info.actions], dtype=np.float32)
+ # discount/bootstrap off value fn
+ mb_returns = np.zeros_like(mb_rewards)
+ mb_advs = np.zeros_like(mb_rewards)
+ lastgaelam = 0
+ last_dones = np.asarray([True for _ in trials_result], dtype=bool) # ugly
+ for t in reversed(range(self.model_config.nsteps)):
+ if t == self.model_config.nsteps - 1:
+ nextnonterminal = 1.0 - last_dones
+ nextvalues = trials_info.last_value
+ else:
+ nextnonterminal = 1.0 - trials_info.dones[t+1]
+ nextvalues = trials_info.values[t+1]
+ delta = mb_rewards[t] + self.model_config.gamma * nextvalues * nextnonterminal - trials_info.values[t]
+ mb_advs[t] = lastgaelam = delta + self.model_config.gamma * self.model_config.lam * nextnonterminal * lastgaelam
+ mb_returns = mb_advs + trials_info.values
+
+ trials_info.update_rewards(mb_rewards, mb_returns)
+ trials_info.convert_shape()
+
+ def train(self, trials_info, nenvs):
+ """
+ train the policy/value network using trials_info
+
+ Parameters:
+ ----------
+ trials_info: complete info of the generated trials from the previous inference
+ nenvs: the batch size of the (previous) inference
+ """
+ # keep frac decay for future optimization
+ if self.cur_update <= self.nupdates:
+ frac = 1.0 - (self.cur_update - 1.0) / self.nupdates
+ else:
+ logger.warning('current update (self.cur_update) %d has exceeded total updates (self.nupdates) %d',
+ self.cur_update, self.nupdates)
+ frac = 1.0 - (self.nupdates - 1.0) / self.nupdates
+ lrnow = self.lr(frac)
+ cliprangenow = self.cliprange(frac)
+ self.cur_update += 1
+
+ states = self.states
+
+ assert states is not None # recurrent version
+ assert nenvs % self.model_config.nminibatches == 0
+ envsperbatch = nenvs // self.model_config.nminibatches
+ envinds = np.arange(nenvs)
+ flatinds = np.arange(nenvs * self.model_config.nsteps).reshape(nenvs, self.model_config.nsteps)
+ for _ in range(self.model_config.noptepochs):
+ np.random.shuffle(envinds)
+ for start in range(0, nenvs, envsperbatch):
+ end = start + envsperbatch
+ mbenvinds = envinds[start:end]
+ mbflatinds = flatinds[mbenvinds].ravel()
+ slices = (arr[mbflatinds] for arr in (trials_info.obs, trials_info.returns, trials_info.dones,
+ trials_info.actions, trials_info.values, trials_info.neglogpacs))
+ mbstates = states[mbenvinds]
+ self.model.train(lrnow, cliprangenow, *slices, mbstates)
+
+
+class PPOTuner(Tuner):
+ """
+ PPOTuner
+ """
+
+ def __init__(self, optimize_mode, trials_per_update=20, epochs_per_update=4, minibatch_size=4,
+ ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, cliprange=0.2):
+ """
+ initialization, PPO model is not initialized here as search space is not received yet.
+
+ Parameters:
+ ----------
+ optimize_mode: maximize or minimize
+ trials_per_update: number of trials to have for each model update
+ epochs_per_update: number of epochs to run for each model update
+ minibatch_size: minibatch size (number of trials) for the update
+ ent_coef: policy entropy coefficient in the optimization objective
+ lr: learning rate of the model (lstm network), constant
+ vf_coef: value function loss coefficient in the optimization objective
+ max_grad_norm: gradient norm clipping coefficient
+ gamma: discounting factor
+ lam: advantage estimation discounting factor (lambda in the paper)
+ cliprange: cliprange in the PPO algorithm, constant
+ """
+ self.optimize_mode = OptimizeMode(optimize_mode)
+ self.model_config = ModelConfig()
+ self.model = None
+ self.search_space = None
+ self.running_trials = {} # key: parameter_id, value: actions/states/etc.
+ self.inf_batch_size = trials_per_update # number of trials to generate in one inference
+ self.first_inf = True # indicate whether it is the first time to inference new trials
+ self.trials_result = [None for _ in range(self.inf_batch_size)] # results of finished trials
+
+ self.credit = 0 # record the unsatisfied trial requests
+ self.param_ids = []
+ self.finished_trials = 0
+ self.chosen_arch_template = {}
+
+ self.actions_spaces = None
+ self.actions_to_config = None
+ self.full_act_space = None
+ self.trials_info = None
+
+ self.all_trials = {} # used to dedup the same trial, key: config, value: final result
+
+ self.model_config.num_envs = self.inf_batch_size
+ self.model_config.noptepochs = epochs_per_update
+ self.model_config.nminibatches = minibatch_size
+
+ self.send_trial_callback = None
+ logger.info('=== finished PPOTuner initialization')
+
+ def _process_one_nas_space(self, block_name, block_space):
+ """
+ process nas space to determine observation space and action space
+
+ Parameters:
+ ----------
+ block_name: the name of the mutable block
+ block_space: search space of this mutable block
+
+ Returns:
+ ----------
+ actions_spaces: list of the space of each action
+ actions_to_config: the mapping from action to generated configuration
+ """
+ actions_spaces = []
+ actions_to_config = []
+
+ block_arch_temp = {}
+ for l_name, layer in block_space.items():
+ chosen_layer_temp = {}
+
+ if len(layer['layer_choice']) > 1:
+ actions_spaces.append(layer['layer_choice'])
+ actions_to_config.append((block_name, l_name, 'chosen_layer'))
+ chosen_layer_temp['chosen_layer'] = None
+ else:
+ assert len(layer['layer_choice']) == 1
+ chosen_layer_temp['chosen_layer'] = layer['layer_choice'][0]
+
+ if layer['optional_input_size'] not in [0, 1, [0, 1]]:
+ raise ValueError('Optional_input_size can only be 0, 1, or [0, 1], but the specified one is %s'
+ % (layer['optional_input_size']))
+ if isinstance(layer['optional_input_size'], list):
+ actions_spaces.append(["None", *layer['optional_inputs']])
+ actions_to_config.append((block_name, l_name, 'chosen_inputs'))
+ chosen_layer_temp['chosen_inputs'] = None
+ elif layer['optional_input_size'] == 1:
+ actions_spaces.append(layer['optional_inputs'])
+ actions_to_config.append((block_name, l_name, 'chosen_inputs'))
+ chosen_layer_temp['chosen_inputs'] = None
+ elif layer['optional_input_size'] == 0:
+ chosen_layer_temp['chosen_inputs'] = []
+ else:
+ raise ValueError('invalid type and value of optional_input_size')
+
+ block_arch_temp[l_name] = chosen_layer_temp
+
+ self.chosen_arch_template[block_name] = block_arch_temp
+
+ return actions_spaces, actions_to_config
+
+ def _process_nas_space(self, search_space):
+ """
+ process nas search space to get action/observation space
+ """
+ actions_spaces = []
+ actions_to_config = []
+ for b_name, block in search_space.items():
+ if block['_type'] != 'mutable_layer':
+ raise ValueError('PPOTuner only accepts mutable_layer type in search space, but the current one is %s'%(block['_type']))
+ block = block['_value']
+ act, act_map = self._process_one_nas_space(b_name, block)
+ actions_spaces.extend(act)
+ actions_to_config.extend(act_map)
+
+ # calculate observation space
+ dedup = {}
+ for step in actions_spaces:
+ for action in step:
+ dedup[action] = 1
+ full_act_space = [act for act, _ in dedup.items()]
+ assert len(full_act_space) == len(dedup)
+ observation_space = len(full_act_space)
+
+ nsteps = len(actions_spaces)
+
+ return actions_spaces, actions_to_config, full_act_space, observation_space, nsteps
+
+ def _generate_action_mask(self):
+ """
+ Different steps can have different action spaces. To handle this, all possible actions are merged
+ into one action space, and masks indicate which actions are available at each step. Two masks are
+ returned: a 1/0 mask (available/unavailable) and a 0/-inf mask (available/unavailable).
+ """
+ two_masks = []
+
+ mask = []
+ for acts in self.actions_spaces:
+ one_mask = [0 for _ in range(len(self.full_act_space))]
+ for act in acts:
+ idx = self.full_act_space.index(act)
+ one_mask[idx] = 1
+ mask.append(one_mask)
+ two_masks.append(mask)
+
+ mask = []
+ for acts in self.actions_spaces:
+ one_mask = [-np.inf for _ in range(len(self.full_act_space))]
+ for act in acts:
+ idx = self.full_act_space.index(act)
+ one_mask[idx] = 0
+ mask.append(one_mask)
+ two_masks.append(mask)
+
+ return np.asarray(two_masks, dtype=np.float32)
+
+ def update_search_space(self, search_space):
+ """
+ get search space, currently the space only includes that for NAS
+
+ Parameters:
+ ----------
+ search_space: search space for NAS
+
+ Returns:
+ -------
+ no return
+ """
+ logger.info('=== update search space %s', search_space)
+ assert self.search_space is None
+ self.search_space = search_space
+
+ assert self.model_config.observation_space is None
+ assert self.model_config.action_space is None
+
+ self.actions_spaces, self.actions_to_config, self.full_act_space, obs_space, nsteps = self._process_nas_space(search_space)
+
+ self.model_config.observation_space = spaces.Discrete(obs_space)
+ self.model_config.action_space = spaces.Discrete(obs_space)
+ self.model_config.nsteps = nsteps
+
+ # generate mask in numpy
+ mask = self._generate_action_mask()
+
+ assert self.model is None
+ self.model = PPOModel(self.model_config, mask)
+
+ def _actions_to_config(self, actions):
+ """
+ given actions, to generate the corresponding trial configuration
+ """
+ chosen_arch = copy.deepcopy(self.chosen_arch_template)
+ for cnt, act in enumerate(actions):
+ act_name = self.full_act_space[act]
+ (block_name, layer_name, key) = self.actions_to_config[cnt]
+ if key == 'chosen_inputs':
+ if act_name == 'None':
+ chosen_arch[block_name][layer_name][key] = []
+ else:
+ chosen_arch[block_name][layer_name][key] = [act_name]
+ elif key == 'chosen_layer':
+ chosen_arch[block_name][layer_name][key] = act_name
+ else:
+ raise ValueError('unrecognized key: {0}'.format(key))
+ return chosen_arch
+
+ def generate_multiple_parameters(self, parameter_id_list, **kwargs):
+ """
+ Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
+ """
+ result = []
+ self.send_trial_callback = kwargs['st_callback']
+ for parameter_id in parameter_id_list:
+ had_exception = False
+ try:
+ logger.debug("generating param for %s", parameter_id)
+ res = self.generate_parameters(parameter_id, **kwargs)
+ except nni.NoMoreTrialError:
+ had_exception = True
+ if not had_exception:
+ result.append(res)
+ return result
+
+ def generate_parameters(self, parameter_id, **kwargs):
+ """
+ Generate parameters. If no trial configuration is available now, increment self.credit so the config is sent later.
+ """
+ if self.first_inf:
+ self.trials_result = [None for _ in range(self.inf_batch_size)]
+ mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size)
+ self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs,
+ mb_dones, last_values, self.inf_batch_size)
+ self.first_inf = False
+
+ trial_info_idx, actions = self.trials_info.get_next()
+ if trial_info_idx is None:
+ self.credit += 1
+ self.param_ids.append(parameter_id)
+ raise nni.NoMoreTrialError('no more parameters now.')
+
+ self.running_trials[parameter_id] = trial_info_idx
+ new_config = self._actions_to_config(actions)
+ return new_config
+
+ def _next_round_inference(self):
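+ """
+ Train the model on the finished batch of trials, run inference to generate the next batch,
+ and submit new trials for the requests recorded in self.credit.
+ """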
+ self.finished_trials = 0
+ self.model.compute_rewards(self.trials_info, self.trials_result)
+ self.model.train(self.trials_info, self.inf_batch_size)
+ self.running_trials = {}
+ # generate new trials
+ self.trials_result = [None for _ in range(self.inf_batch_size)]
+ mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size)
+ self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs,
+ mb_dones, last_values, self.inf_batch_size)
+ # check credit and submit new trials
+ for _ in range(self.credit):
+ trial_info_idx, actions = self.trials_info.get_next()
+ if trial_info_idx is None:
+ logger.warning('Not enough trial configs, trials_per_update is suggested to be larger than trialConcurrency')
+ break
+ assert self.param_ids
+ param_id = self.param_ids.pop()
+ self.running_trials[param_id] = trial_info_idx
+ new_config = self._actions_to_config(actions)
+ self.send_trial_callback(param_id, new_config)
+ self.credit -= 1
+
+ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+ """
+ receive trial's result. if the number of finished trials equals self.inf_batch_size, start the next update to
+ train the model
+ """
+ trial_info_idx = self.running_trials.pop(parameter_id, None)
+ assert trial_info_idx is not None
+
+ value = extract_scalar_reward(value)
+ if self.optimize_mode == OptimizeMode.Minimize:
+ value = -value
+
+ self.trials_result[trial_info_idx] = value
+ self.finished_trials += 1
+
+ if self.finished_trials == self.inf_batch_size:
+ self._next_round_inference()
+
+ def trial_end(self, parameter_id, success, **kwargs):
+ """
+ to deal with trial failure
+ """
+ if not success:
+ if parameter_id not in self.running_trials:
+ logger.warning('The trial failed, but self.running_trials does not have this trial')
+ return
+ trial_info_idx = self.running_trials.pop(parameter_id, None)
+ assert trial_info_idx is not None
+ # use mean of finished trials as the result of this failed trial
+ values = [val for val in self.trials_result if val is not None]
+ logger.debug('results of finished trials: %s', values)
+ self.trials_result[trial_info_idx] = (sum(values) / len(values)) if len(values) > 0 else 0
+ self.finished_trials += 1
+ if self.finished_trials == self.inf_batch_size:
+ self._next_round_inference()
+
+ def import_data(self, data):
+ """
+ Import additional data for tuning
+
+ Parameters
+ ----------
+ data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
+ """
+ logger.warning('PPOTuner cannot leverage imported data.')
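Review note on `compute_rewards` above: the backward loop is generalized advantage estimation (GAE). The following is a minimal sketch of the same recurrence for a single trial, written with plain Python lists so it runs without NumPy. The function name `gae_advantages` and its argument layout are illustrative assumptions, not part of this patch; as in the patch, the step after the last one is treated as terminal.

```python
def gae_advantages(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """GAE for one trajectory (hypothetical helper, mirrors compute_rewards).

    rewards[t], values[t] and dones[t] describe step t; last_value is the
    value estimate after the final step.
    """
    nsteps = len(rewards)
    advs = [0.0] * nsteps
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            # compute_rewards uses last_dones = True, so nothing is bootstrapped here
            nextnonterminal = 0.0
            nextvalue = last_value
        else:
            nextnonterminal = 1.0 - float(dones[t + 1])
            nextvalue = values[t + 1]
        # one-step TD error
        delta = rewards[t] + gamma * nextvalue * nextnonterminal - values[t]
        # discounted sum of TD errors
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        advs[t] = lastgaelam
    # returns are advantages plus the value baseline
    returns = [a + v for a, v in zip(advs, values)]
    return advs, returns


advs, returns = gae_advantages([1.0, 1.0], [0.5, 0.5], 0.0, [False, False],
                               gamma=1.0, lam=1.0)
print(advs, returns)
```

With `gamma = lam = 1` the return at step 0 equals the undiscounted sum of rewards, which is a quick sanity check on the recurrence. Note that in the patch every step of a trial receives the same reward, namely the trial's final result.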
diff --git a/src/sdk/pynni/nni/ppo_tuner/requirements.txt b/src/sdk/pynni/nni/ppo_tuner/requirements.txt
new file mode 100644
index 0000000000..138951469b
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/requirements.txt
@@ -0,0 +1,3 @@
+enum34
+gym
+tensorflow
\ No newline at end of file
diff --git a/src/sdk/pynni/nni/ppo_tuner/util.py b/src/sdk/pynni/nni/ppo_tuner/util.py
new file mode 100644
index 0000000000..ac958e54de
--- /dev/null
+++ b/src/sdk/pynni/nni/ppo_tuner/util.py
@@ -0,0 +1,266 @@
+# Copyright (c) Microsoft Corporation
+# All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge,
+# to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and
+# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
+# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+util functions
+"""
+
+import os
+import random
+import multiprocessing
+import numpy as np
+import tensorflow as tf
+from gym.spaces import Discrete, Box, MultiDiscrete
+
+def set_global_seeds(i):
+ """set global seeds"""
+ rank = 0
+ myseed = i + 1000 * rank if i is not None else None
+ tf.set_random_seed(myseed)
+ np.random.seed(myseed)
+ random.seed(myseed)
+
+def batch_to_seq(h, nbatch, nsteps, flat=False):
+ """convert from batch to sequence"""
+ if flat:
+ h = tf.reshape(h, [nbatch, nsteps])
+ else:
+ h = tf.reshape(h, [nbatch, nsteps, -1])
+ return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
+
+def seq_to_batch(h, flat=False):
+ """convert from sequence to batch"""
+ shape = h[0].get_shape().as_list()
+ if not flat:
+ assert len(shape) > 1
+ nh = h[0].get_shape()[-1].value
+ return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
+ else:
+ return tf.reshape(tf.stack(values=h, axis=1), [-1])
+
+def lstm(xs, ms, s, scope, nh, init_scale=1.0):
+ """lstm cell"""
+ nbatch, nin = [v.value for v in xs[0].get_shape()]
+ with tf.variable_scope(scope):
+ wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
+ wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
+ b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
+
+ c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
+ for idx, (x, m) in enumerate(zip(xs, ms)):
+ c = c*(1-m)
+ h = h*(1-m)
+ z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
+ i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
+ i = tf.nn.sigmoid(i)
+ f = tf.nn.sigmoid(f)
+ o = tf.nn.sigmoid(o)
+ u = tf.tanh(u)
+ c = f*c + i*u
+ h = o*tf.tanh(c)
+ xs[idx] = h
+ s = tf.concat(axis=1, values=[c, h])
+ return xs, s
+
+def lstm_model(nlstm=128, layer_norm=False):
+ """
+ Builds LSTM (Long-Short Term Memory) network to be used in a policy.
+ Note that the resulting function returns not only the output of the LSTM
+ (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
+ with auxiliary tensors to be set as policy attributes.
+
+ Specifically,
+ S is a placeholder to feed current state (LSTM state has to be managed outside policy)
+ M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
+ initial_state is a numpy array containing initial lstm state (usually zeros)
+ state is the output LSTM state (to be fed into S at the next call)
+
+
+ An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
+
+ Parameters:
+ ----------
+ nlstm: int LSTM hidden state size
+ layer_norm: bool if True, layer-normalized version of LSTM is used
+
+ Returns:
+ -------
+ function that builds LSTM with a given input tensor / placeholder
+ """
+
+ def network_fn(X, nenv=1, obs_size=-1):
+ with tf.variable_scope("emb", reuse=tf.AUTO_REUSE):
+ w_emb = tf.get_variable("w_emb", [obs_size+1, 32])
+ X = tf.nn.embedding_lookup(w_emb, X)
+
+ nbatch = X.shape[0]
+ nsteps = nbatch // nenv
+
+ h = tf.layers.flatten(X)
+
+ M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
+ S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
+
+ xs = batch_to_seq(h, nenv, nsteps)
+ ms = batch_to_seq(M, nenv, nsteps)
+
+ assert not layer_norm
+ h5, snew = lstm(xs, ms, S, scope='lstm', nh=nlstm)
+
+ h = seq_to_batch(h5)
+ initial_state = np.zeros(S.shape.as_list(), dtype=float)
+
+ return h, {'S':S, 'M':M, 'state':snew, 'initial_state':initial_state}
+
+ return network_fn
+
+def ortho_init(scale=1.0):
+ """init approach"""
+ def _ortho_init(shape, dtype, partition_info=None):
+ #lasagne ortho init for tf
+ shape = tuple(shape)
+ if len(shape) == 2:
+ flat_shape = shape
+ elif len(shape) == 4: # assumes NHWC
+ flat_shape = (np.prod(shape[:-1]), shape[-1])
+ else:
+ raise NotImplementedError
+ a = np.random.normal(0.0, 1.0, flat_shape)
+ u, _, v = np.linalg.svd(a, full_matrices=False)
+ q = u if u.shape == flat_shape else v # pick the one with the correct shape
+ q = q.reshape(shape)
+ return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
+ return _ortho_init
+
+def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
+ """fully connected op"""
+ with tf.variable_scope(scope):
+ nin = x.get_shape()[1].value
+ w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
+ b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
+ return tf.matmul(x, w)+b
+
+def _check_shape(placeholder_shape, data_shape):
+ """
+ check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension).
+ Currently a stub that always returns True.
+ """
+
+ return True
+
+# ================================================================
+# Shape adjustment for feeding into tf placeholders
+# ================================================================
+def adjust_shape(placeholder, data):
+ """
+ adjust shape of the data to the shape of the placeholder if possible.
+ If shape is incompatible, AssertionError is thrown
+
+ Parameters:
+ placeholder: tensorflow input placeholder
+ data: input data to be (potentially) reshaped to be fed into placeholder
+
+ Returns:
+ reshaped data
+ """
+ if not isinstance(data, np.ndarray) and not isinstance(data, list):
+ return data
+ if isinstance(data, list):
+ data = np.array(data)
+
+ placeholder_shape = [x or -1 for x in placeholder.shape.as_list()]
+
+ assert _check_shape(placeholder_shape, data.shape), \
+ 'Shape of data {} is not compatible with shape of the placeholder {}'.format(data.shape, placeholder_shape)
+
+ return np.reshape(data, placeholder_shape)
+
+# ================================================================
+# Global session
+# ================================================================
+
+def get_session(config=None):
+ """Get default session or create one with a given config"""
+ sess = tf.get_default_session()
+ if sess is None:
+ sess = make_session(config=config, make_default=True)
+ return sess
+
+def make_session(config=None, num_cpu=None, make_default=False, graph=None):
+ """Returns a session whose op parallelism is limited to num_cpu threads"""
+ if num_cpu is None:
+ num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
+ if config is None:
+ config = tf.ConfigProto(
+ allow_soft_placement=True,
+ inter_op_parallelism_threads=num_cpu,
+ intra_op_parallelism_threads=num_cpu)
+ config.gpu_options.allow_growth = True
+
+ if make_default:
+ return tf.InteractiveSession(config=config, graph=graph)
+ else:
+ return tf.Session(config=config, graph=graph)
+
+ALREADY_INITIALIZED = set()
+
+def initialize():
+ """Initialize all the uninitialized variables in the global scope."""
+ new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
+ get_session().run(tf.variables_initializer(new_variables))
+
+ ALREADY_INITIALIZED.update(new_variables)
+
+def observation_placeholder(ob_space, batch_size=None, name='Ob'):
+ """
+ Create a placeholder, sized to match the observation space, into which observations are fed
+
+ Parameters:
+ ----------
+ ob_space: gym.Space observation space
+ batch_size: int size of the batch to be fed into input. Can be left None in most cases.
+ name: str name of the placeholder
+
+ Returns:
+ -------
+ tensorflow placeholder tensor
+ """
+
+ assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
+ 'Can only deal with Discrete, Box and MultiDiscrete observation spaces for now'
+
+ dtype = ob_space.dtype
+ if dtype == np.int8:
+ dtype = np.uint8
+
+ return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
+
+def explained_variance(ypred, y):
+ """
+ Computes fraction of variance that ypred explains about y.
+ Returns 1 - Var[y-ypred] / Var[y]
+
+ interpretation:
+ ev=0 => might as well have predicted zero
+ ev=1 => perfect prediction
+ ev<0 => worse than just predicting zero
+
+ """
+ assert y.ndim == 1 and ypred.ndim == 1
+ vary = np.var(y)
+ return np.nan if vary == 0 else 1 - np.var(y-ypred)/vary
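Review note on `explained_variance` above: a small worked check of the three cases listed in its docstring. This sketch swaps `np.var` for the standard library's `statistics.pvariance` (both compute the population variance) so it runs without NumPy; that substitution is an assumption made for self-containment, not what the patch does.

```python
import math
from statistics import pvariance


def explained_variance(ypred, y):
    """Return 1 - Var[y - ypred] / Var[y], or NaN when y has zero variance."""
    vary = pvariance(y)
    if vary == 0:
        return float("nan")
    residual = [yi - pi for yi, pi in zip(y, ypred)]
    return 1 - pvariance(residual) / vary


y = [1.0, 2.0, 3.0]
perfect = explained_variance(y, y)                   # prediction equals target
zero_pred = explained_variance([0.0, 0.0, 0.0], y)   # always predict zero
constant_y = explained_variance(y, [1.0, 1.0, 1.0])  # target has no variance
print(perfect, zero_pred, math.isnan(constant_y))
```

Because the variance ignores constant offsets, any constant prediction scores 0, not only the all-zero one.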
diff --git a/src/sdk/pynni/nni/trial.py b/src/sdk/pynni/nni/trial.py
index 132fd96834..89ceeb4a49 100644
--- a/src/sdk/pynni/nni/trial.py
+++ b/src/sdk/pynni/nni/trial.py
@@ -43,7 +43,8 @@
def get_next_parameter():
- """Returns a set of (hyper-)paremeters generated by Tuner."""
+ """Returns a set of (hyper-)parameters generated by Tuner.
+ Returns None if no more (hyper-)parameters can be generated by Tuner."""
global _params
_params = platform.get_next_parameter()
if _params is None:
diff --git a/src/sdk/pynni/nni/tuner.py b/src/sdk/pynni/nni/tuner.py
index 0d995f94c5..c9c72d479b 100644
--- a/src/sdk/pynni/nni/tuner.py
+++ b/src/sdk/pynni/nni/tuner.py
@@ -17,11 +17,10 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
-
-
import logging
import nni
+
from .recoverable import Recoverable
_logger = logging.getLogger(__name__)
@@ -57,19 +56,22 @@ def generate_multiple_parameters(self, parameter_id_list, **kwargs):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Invoked when a trial reports its final result. Must override.
+ By default this only reports results of algorithm-generated hyper-parameters.
+ Use `accept_customized_trials()` to receive results from user-added parameters.
parameter_id: int
parameters: object created by 'generate_parameters()'
- reward: object reported by trial
+ value: object reported by trial
+ customized: bool, true if the trial is created from web UI, false if generated by algorithm
+ trial_job_id: str, only available in multiphase mode.
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
- def receive_customized_trial_result(self, parameter_id, parameters, value, **kwargs):
- """Invoked when a trial added by WebUI reports its final result. Do nothing by default.
- parameter_id: int
- parameters: object created by user
- value: object reported by trial
+ def accept_customized_trials(self, accept=True):
+ """Enable or disable receiving results of user-added hyper-parameters.
+ By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters.
+ If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`.
"""
- _logger.info('Customized trial job %s ignored by tuner', parameter_id)
+ self._accept_customized = accept
def trial_end(self, parameter_id, success, **kwargs):
"""Invoked when a trial is completed or terminated. Do nothing by default.
diff --git a/src/sdk/pynni/tests/test_tuner.py b/src/sdk/pynni/tests/test_tuner.py
index 41e80cfa6d..c1fd3594ee 100644
--- a/src/sdk/pynni/tests/test_tuner.py
+++ b/src/sdk/pynni/tests/test_tuner.py
@@ -34,6 +34,7 @@ def __init__(self):
self.param = 0
self.trial_results = [ ]
self.search_space = None
+ self.accept_customized_trials()
def generate_parameters(self, parameter_id, **kwargs):
# report Tuner's internal states to generated parameters,
@@ -45,13 +46,9 @@ def generate_parameters(self, parameter_id, **kwargs):
'search_space': self.search_space
}
- def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+ def receive_trial_result(self, parameter_id, parameters, value, customized, **kwargs):
reward = extract_scalar_reward(value)
- self.trial_results.append((parameter_id, parameters['param'], reward, False))
-
- def receive_customized_trial_result(self, parameter_id, parameters, value):
- reward = extract_scalar_reward(value)
- self.trial_results.append((parameter_id, parameters['param'], reward, True))
+ self.trial_results.append((parameter_id, parameters['param'], reward, customized))
def update_search_space(self, search_space):
self.search_space = search_space
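Review note on the `accept_customized_trials()` change above: a minimal sketch of how the opt-in could behave from a tuner's point of view. `SketchTuner` is a hypothetical stand-in for `nni.tuner.Tuner`, and `dispatch` is an assumed model of the caller; the real dispatcher is not part of this excerpt and may differ.

```python
class SketchTuner:
    """Hypothetical stand-in for nni.tuner.Tuner; only the opt-in flag is modeled."""

    def __init__(self):
        self._accept_customized = False

    def accept_customized_trials(self, accept=True):
        self._accept_customized = accept


class RecordingTuner(SketchTuner):
    def __init__(self):
        super().__init__()
        self.trial_results = []
        # opt in to results of user-added (customized) trials
        self.accept_customized_trials()

    def receive_trial_result(self, parameter_id, parameters, value, customized=False, **kwargs):
        self.trial_results.append((parameter_id, value, customized))


def dispatch(tuner, parameter_id, parameters, value, customized):
    """Assumed dispatcher logic: drop customized results unless the tuner opted in."""
    if customized and not tuner._accept_customized:
        return False
    tuner.receive_trial_result(parameter_id, parameters, value, customized=customized)
    return True


tuner = RecordingTuner()
delivered = dispatch(tuner, 1, {"param": 0}, 0.9, customized=True)
tuner.accept_customized_trials(False)
dropped = dispatch(tuner, 2, {"param": 1}, 0.8, customized=True)
print(delivered, dropped, tuner.trial_results)
```

The design merges the two callbacks into one and passes `customized` as a flag, so a tuner that wants both kinds of results handles them in a single method.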
diff --git a/src/webui/src/App.tsx b/src/webui/src/App.tsx
index c3b31d422a..b55129d6e3 100644
--- a/src/webui/src/App.tsx
+++ b/src/webui/src/App.tsx
@@ -1,104 +1,111 @@
import * as React from 'react';
import { Row, Col } from 'antd';
-import axios from 'axios';
-import { COLUMN, MANAGER_IP } from './static/const';
+import { COLUMN } from './static/const';
+import { EXPERIMENT, TRIALS } from './static/datamodel';
import './App.css';
import SlideBar from './components/SlideBar';
interface AppState {
- interval: number;
- whichPageToFresh: string;
- columnList: Array;
- concurrency: number;
+ interval: number;
+ columnList: Array;
+ experimentUpdateBroadcast: number;
+ trialsUpdateBroadcast: number;
}
class App extends React.Component<{}, AppState> {
- public _isMounted: boolean;
- constructor(props: {}) {
- super(props);
- this.state = {
- interval: 10, // sendons
- whichPageToFresh: '',
- columnList: COLUMN,
- concurrency: 1
- };
- }
+ private timerId: number | null;
- changeInterval = (interval: number) => {
- if (this._isMounted === true) {
- this.setState(() => ({ interval: interval }));
+ constructor(props: {}) {
+ super(props);
+ this.state = {
+ interval: 10, // seconds
+ columnList: COLUMN,
+ experimentUpdateBroadcast: 0,
+ trialsUpdateBroadcast: 0,
+ };
}
- }
- changeFresh = (fresh: string) => {
- // interval * 1000
- if (this._isMounted === true) {
- this.setState(() => ({ whichPageToFresh: fresh }));
+ async componentDidMount() {
+ await Promise.all([ EXPERIMENT.init(), TRIALS.init() ]);
+ this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
+ this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
+ this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
}
- }
- changeColumn = (columnList: Array) => {
- if (this._isMounted === true) {
- this.setState(() => ({ columnList: columnList }));
+ changeInterval = (interval: number) => {
+ this.setState({ interval: interval });
+ if (this.timerId === null && interval !== 0) {
+ window.setTimeout(this.refresh);
+ } else if (this.timerId !== null && interval === 0) {
+ window.clearTimeout(this.timerId);
+ }
}
- }
- changeConcurrency = (val: number) => {
- if (this._isMounted === true) {
- this.setState(() => ({ concurrency: val }));
+ // TODO: use local storage
+ changeColumn = (columnList: Array) => {
+ this.setState({ columnList: columnList });
}
- }
- getConcurrency = () => {
- axios(`${MANAGER_IP}/experiment`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- const params = res.data.params;
- if (this._isMounted) {
- this.setState(() => ({ concurrency: params.trialConcurrency }));
- }
+ render() {
+ const { interval, columnList, experimentUpdateBroadcast, trialsUpdateBroadcast } = this.state;
+ if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) {
+ return null; // TODO: render a loading page
+ }
+ const reactPropsChildren = React.Children.map(this.props.children, child =>
+ React.cloneElement(
+ // tslint:disable-next-line:no-any
+ child as React.ReactElement, {
+ interval,
+ columnList, changeColumn: this.changeColumn,
+ experimentUpdateBroadcast,
+ trialsUpdateBroadcast,
+ })
+ );
+ return (
+
+
+
+
+
+
+
+
+
+
+ {reactPropsChildren}
+
+
+
+ );
+ }
+
+ private refresh = async () => {
+ const [ experimentUpdated, trialsUpdated ] = await Promise.all([ EXPERIMENT.update(), TRIALS.update() ]);
+ if (experimentUpdated) {
+ this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
+ }
+ if (trialsUpdated) {
+ this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
}
- });
- }
- componentDidMount() {
- this._isMounted = true;
- this.getConcurrency();
- }
+ if ([ 'DONE', 'ERROR', 'STOPPED' ].includes(EXPERIMENT.status)) {
+ // experiment finished, refresh once more to ensure consistency
+ if (this.state.interval > 0) {
+ this.setState({ interval: 0 });
+ this.lastRefresh();
+ }
- componentWillUnmount() {
- this._isMounted = false;
- }
- render() {
- const { interval, whichPageToFresh, columnList, concurrency } = this.state;
- const reactPropsChildren = React.Children.map(this.props.children, child =>
- React.cloneElement(
- // tslint:disable-next-line:no-any
- child as React.ReactElement, {
- interval, whichPageToFresh,
- columnList, changeColumn: this.changeColumn,
- concurrency, changeConcurrency: this.changeConcurrency
- })
- );
- return (
-
-
-
-
-
-
-
-
-
-
- {reactPropsChildren}
-
-
-
- );
- }
+ } else if (this.state.interval !== 0) {
+ this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
+ }
+ }
+
+ private async lastRefresh() {
+ await EXPERIMENT.update();
+ await TRIALS.update(true);
+ this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
+ this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
+ }
}
export default App;
diff --git a/src/webui/src/components/Modal/Compare.tsx b/src/webui/src/components/Modal/Compare.tsx
index ec0c3a9c01..8a7278c966 100644
--- a/src/webui/src/components/Modal/Compare.tsx
+++ b/src/webui/src/components/Modal/Compare.tsx
@@ -2,12 +2,13 @@ import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
+import { TRIALS } from '../../static/datamodel';
import '../../static/style/compare.scss';
-import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
+import { TableRecord, Intermedia, TooltipForIntermediate } from 'src/static/interface';
// the modal of trial compare
interface CompareProps {
- compareRows: Array<TableObj>;
+ compareRows: Array<TableRecord>;
visible: boolean;
cancelFunc: () => void;
}
@@ -25,11 +26,12 @@ class Compare extends React.Component {
const idsList: Array = [];
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
+ const trial = TRIALS.getTrial(temp.id);
trialIntermediate.push({
name: temp.id,
- data: temp.description.intermediate,
+ data: trial.description.intermediate,
type: 'line',
- hyperPara: temp.description.parameters
+ hyperPara: trial.description.parameters
});
idsList.push(temp.id);
});
@@ -105,10 +107,12 @@ class Compare extends React.Component {
// render table column ---
initColumn = () => {
- const { compareRows } = this.props;
const idList: Array = [];
+ const sequenceIdList: Array = [];
const durationList: Array = [];
+ const compareRows = this.props.compareRows.map(tableRecord => TRIALS.getTrial(tableRecord.id));
+
const parameterList: Array = [];
let parameterKeys: Array = [];
if (compareRows.length !== 0) {
@@ -117,6 +121,7 @@ class Compare extends React.Component {
Object.keys(compareRows).map(item => {
const temp = compareRows[item];
idList.push(temp.id);
+ sequenceIdList.push(temp.sequenceId);
durationList.push(temp.duration);
parameterList.push(temp.description.parameters);
});
@@ -124,20 +129,28 @@ class Compare extends React.Component {
-
+ Id
{Object.keys(idList).map(key => {
return (
{idList[key]}
);
})}
+
+ Trial No.
+ {Object.keys(sequenceIdList).map(key => {
+ return (
+ {sequenceIdList[key]}
+ );
+ })}
+
Default metric
{Object.keys(compareRows).map(index => {
const temp = compareRows[index];
return (
-
+
);
})}
@@ -196,7 +209,7 @@ class Compare extends React.Component {
>
{this.intermediate()}
- # Intermediate
+ # Intermediate result
{this.initColumn()}
diff --git a/src/webui/src/components/Modal/ExperimentDrawer.tsx b/src/webui/src/components/Modal/ExperimentDrawer.tsx
index 2433eec439..2541811bf3 100644
--- a/src/webui/src/components/Modal/ExperimentDrawer.tsx
+++ b/src/webui/src/components/Modal/ExperimentDrawer.tsx
@@ -58,7 +58,7 @@ class ExperimentDrawer extends React.Component {
trialMessage: trialMessagesArr
};
if (this._isCompareMount === true) {
- this.setState(() => ({ experiment: JSON.stringify(result, null, 4) }));
+ this.setState({ experiment: JSON.stringify(result, null, 4) });
}
}
}));
diff --git a/src/webui/src/components/Modal/LogDrawer.tsx b/src/webui/src/components/Modal/LogDrawer.tsx
index 89a9e90798..bd4abcdb18 100644
--- a/src/webui/src/components/Modal/LogDrawer.tsx
+++ b/src/webui/src/components/Modal/LogDrawer.tsx
@@ -51,13 +51,13 @@ class LogDrawer extends React.Component {
setDispatcher = (value: string) => {
if (this._isLogDrawer === true) {
- this.setState(() => ({ isLoadispatcher: false, dispatcherLogStr: value }));
+ this.setState({ isLoadispatcher: false, dispatcherLogStr: value });
}
}
setNNImanager = (val: string) => {
if (this._isLogDrawer === true) {
- this.setState(() => ({ isLoading: false, nniManagerLogStr: val }));
+ this.setState({ isLoading: false, nniManagerLogStr: val });
}
}
diff --git a/src/webui/src/components/Overview.tsx b/src/webui/src/components/Overview.tsx
index 7366b45511..22d52e5458 100644
--- a/src/webui/src/components/Overview.tsx
+++ b/src/webui/src/components/Overview.tsx
@@ -1,16 +1,14 @@
import * as React from 'react';
-import axios from 'axios';
import { Row, Col } from 'antd';
-import { MANAGER_IP } from '../static/const';
-import { Experiment, TableObj, Parameters, TrialNumber } from '../static/interface';
-import { getFinal } from '../static/function';
+import { EXPERIMENT, TRIALS } from '../static/datamodel';
+import { Trial } from '../static/model/trial';
import SuccessTable from './overview/SuccessTable';
import Title1 from './overview/Title1';
import Progressed from './overview/Progress';
import Accuracy from './overview/Accuracy';
import SearchSpace from './overview/SearchSpace';
import BasicInfo from './overview/BasicInfo';
-import TrialPro from './overview/TrialProfile';
+import TrialInfo from './overview/TrialProfile';
require('../static/style/overview.scss');
require('../static/style/logPath.scss');
@@ -18,486 +16,70 @@ require('../static/style/accuracy.css');
require('../static/style/table.scss');
require('../static/style/overviewTitle.scss');
-interface OverviewState {
- tableData: Array;
- experimentAPI: object;
- searchSpace: object;
- status: string;
- errorStr: string;
- trialProfile: Experiment;
- option: object;
- noData: string;
- accuracyData: object;
- bestAccuracy: number;
- accNodata: string;
- trialNumber: TrialNumber;
- isTop10: boolean;
- titleMaxbgcolor?: string;
- titleMinbgcolor?: string;
- // trial stdout is content(false) or link(true)
- isLogCollection: boolean;
- isMultiPhase: boolean;
+interface OverviewProps {
+ experimentUpdateBroadcast: number;
+ trialsUpdateBroadcast: number;
}
-interface OverviewProps {
- interval: number; // user select
- whichPageToFresh: string;
- concurrency: number;
- changeConcurrency: (val: number) => void;
+interface OverviewState {
+ trialConcurrency: number;
+ metricGraphMode: 'max' | 'min';
}
class Overview extends React.Component {
-
- public _isMounted = false;
- public intervalID = 0;
- public intervalProfile = 1;
-
constructor(props: OverviewProps) {
super(props);
this.state = {
- searchSpace: {},
- experimentAPI: {},
- status: '',
- errorStr: '',
- trialProfile: {
- id: '',
- author: '',
- experName: '',
- runConcurren: 1,
- maxDuration: 0,
- execDuration: 0,
- MaxTrialNum: 0,
- startTime: 0,
- tuner: {},
- trainingServicePlatform: ''
- },
- tableData: [],
- option: {},
- noData: '',
- // accuracy
- accuracyData: {},
- accNodata: '',
- bestAccuracy: 0,
- trialNumber: {
- succTrial: 0,
- failTrial: 0,
- stopTrial: 0,
- waitTrial: 0,
- runTrial: 0,
- unknowTrial: 0,
- totalCurrentTrial: 0
- },
- isTop10: true,
- isLogCollection: false,
- isMultiPhase: false
+ trialConcurrency: EXPERIMENT.trialConcurrency,
+ metricGraphMode: (EXPERIMENT.optimizeMode === 'minimize' ? 'min' : 'max'),
};
}
- // show session
- showSessionPro = () => {
- axios(`${MANAGER_IP}/experiment`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- let sessionData = res.data;
- let trialPro = [];
- const tempara = sessionData.params;
- const trainingPlatform = tempara.trainingServicePlatform;
- // assessor clusterMeteData
- const clusterMetaData = tempara.clusterMetaData;
- const endTimenum = sessionData.endTime;
- const assessor = tempara.assessor;
- const advisor = tempara.advisor;
- let optimizeMode = 'other';
- if (tempara.tuner !== undefined) {
- if (tempara.tuner.classArgs !== undefined) {
- if (tempara.tuner.classArgs.optimize_mode !== undefined) {
- optimizeMode = tempara.tuner.classArgs.optimize_mode;
- }
- }
- }
- // default logCollection is true
- const logCollection = tempara.logCollection;
- let expLogCollection: boolean = false;
- const isMultiy: boolean = tempara.multiPhase !== undefined
- ? tempara.multiPhase : false;
- if (optimizeMode !== undefined) {
- if (optimizeMode === 'minimize') {
- if (this._isMounted) {
- this.setState({
- isTop10: false,
- titleMinbgcolor: '#999'
- });
- }
- } else {
- if (this._isMounted) {
- this.setState({
- isTop10: true,
- titleMaxbgcolor: '#999'
- });
- }
- }
- }
- if (logCollection !== undefined && logCollection !== 'none') {
- expLogCollection = true;
- }
- trialPro.push({
- id: sessionData.id,
- author: tempara.authorName,
- revision: sessionData.revision,
- experName: tempara.experimentName,
- runConcurren: tempara.trialConcurrency,
- logDir: sessionData.logDir ? sessionData.logDir : 'undefined',
- maxDuration: tempara.maxExecDuration,
- execDuration: sessionData.execDuration,
- MaxTrialNum: tempara.maxTrialNum,
- startTime: sessionData.startTime,
- endTime: endTimenum ? endTimenum : undefined,
- trainingServicePlatform: trainingPlatform,
- tuner: tempara.tuner,
- assessor: assessor ? assessor : undefined,
- advisor: advisor ? advisor : undefined,
- clusterMetaData: clusterMetaData ? clusterMetaData : undefined,
- logCollection: logCollection
- });
- // search space format loguniform max and min
- const temp = tempara.searchSpace;
- const searchSpace = temp !== undefined
- ? JSON.parse(temp) : {};
- Object.keys(searchSpace).map(item => {
- const key = searchSpace[item]._type;
- let value = searchSpace[item]._value;
- switch (key) {
- case 'quniform':
- case 'qnormal':
- case 'qlognormal':
- searchSpace[item]._value = [value[0], value[1]];
- break;
-
- default:
-
- }
- });
- if (this._isMounted) {
- this.setState({
- experimentAPI: res.data,
- trialProfile: trialPro[0],
- searchSpace: searchSpace,
- isLogCollection: expLogCollection,
- isMultiPhase: isMultiy
- });
- }
- }
- });
- this.checkStatus();
-
- }
-
- checkStatus = () => {
- axios(`${MANAGER_IP}/check-status`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- const errors = res.data.errors;
- if (errors.length !== 0) {
- if (this._isMounted) {
- this.setState({
- status: res.data.status,
- errorStr: res.data.errors[0]
- });
- }
- } else {
- if (this._isMounted) {
- this.setState({
- status: res.data.status,
- });
- }
- }
- }
- });
- }
-
- showTrials = () => {
- this.isOffInterval();
- axios(`${MANAGER_IP}/trial-jobs`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- const tableData = res.data;
- const topTableData: Array = [];
- const profile: TrialNumber = {
- succTrial: 0,
- failTrial: 0,
- stopTrial: 0,
- waitTrial: 0,
- runTrial: 0,
- unknowTrial: 0,
- totalCurrentTrial: 0
- };
- // currently totoal number
- profile.totalCurrentTrial = tableData.length;
- Object.keys(tableData).map(item => {
- switch (tableData[item].status) {
- case 'WAITING':
- profile.waitTrial += 1;
- break;
-
- case 'UNKNOWN':
- profile.unknowTrial += 1;
- break;
-
- case 'FAILED':
- profile.failTrial += 1;
- break;
-
- case 'RUNNING':
- profile.runTrial += 1;
- break;
-
- case 'USER_CANCELED':
- case 'SYS_CANCELED':
- case 'EARLY_STOPPED':
- profile.stopTrial += 1;
- break;
- case 'SUCCEEDED':
- profile.succTrial += 1;
- const desJobDetail: Parameters = {
- parameters: {},
- intermediate: [],
- multiProgress: 1
- };
- const duration = (tableData[item].endTime - tableData[item].startTime) / 1000;
- const acc = getFinal(tableData[item].finalMetricData);
- // if hyperparameters is undefine, show error message, else, show parameters value
- const tempara = tableData[item].hyperParameters;
- if (tempara !== undefined) {
- const tempLength = tempara.length;
- const parameters = JSON.parse(tempara[tempLength - 1]).parameters;
- desJobDetail.multiProgress = tempara.length;
- if (typeof parameters === 'string') {
- desJobDetail.parameters = JSON.parse(parameters);
- } else {
- desJobDetail.parameters = parameters;
- }
- } else {
- desJobDetail.parameters = { error: 'This trial\'s parameters are not available.' };
- }
- if (tableData[item].logPath !== undefined) {
- desJobDetail.logPath = tableData[item].logPath;
- }
- topTableData.push({
- key: topTableData.length,
- sequenceId: tableData[item].sequenceId,
- id: tableData[item].id,
- duration: duration,
- status: tableData[item].status,
- acc: acc,
- description: desJobDetail
- });
- break;
- default:
- }
- });
- // choose top10 or lowest10
- const { isTop10 } = this.state;
- if (isTop10 === true) {
- topTableData.sort((a: TableObj, b: TableObj) => {
- if (a.acc !== undefined && b.acc !== undefined) {
- return JSON.parse(b.acc.default) - JSON.parse(a.acc.default);
- } else {
- return NaN;
- }
- });
- } else {
- topTableData.sort((a: TableObj, b: TableObj) => {
- if (a.acc !== undefined && b.acc !== undefined) {
- return JSON.parse(a.acc.default) - JSON.parse(b.acc.default);
- } else {
- return NaN;
- }
- });
- }
- topTableData.length = Math.min(10, topTableData.length);
- let bestDefaultMetric = 0;
- if (topTableData[0] !== undefined) {
- if (topTableData[0].acc !== undefined) {
- bestDefaultMetric = JSON.parse(topTableData[0].acc.default);
- }
- }
- if (this._isMounted) {
- this.setState({
- tableData: topTableData,
- trialNumber: profile,
- bestAccuracy: bestDefaultMetric
- });
- }
- this.checkStatus();
- // draw accuracy
- this.drawPointGraph();
- }
- });
- }
-
- // trial accuracy graph Default Metric
- drawPointGraph = () => {
-
- const { tableData } = this.state;
- const sourcePoint = JSON.parse(JSON.stringify(tableData));
- sourcePoint.sort((a: TableObj, b: TableObj) => {
- if (a.sequenceId !== undefined && b.sequenceId !== undefined) {
- return a.sequenceId - b.sequenceId;
- } else {
- return NaN;
- }
- });
- const accarr: Array = [];
- const indexarr: Array = [];
- Object.keys(sourcePoint).map(item => {
- const items = sourcePoint[item];
- if (items.acc !== undefined) {
- accarr.push(items.acc.default);
- indexarr.push(items.sequenceId);
- }
- });
- const accOption = {
- // support max show 0.0000000
- grid: {
- left: 67,
- right: 40
- },
- tooltip: {
- trigger: 'item'
- },
- xAxis: {
- name: 'Trial',
- type: 'category',
- data: indexarr
- },
- yAxis: {
- name: 'Default metric',
- type: 'value',
- scale: true,
- data: accarr
- },
- series: [{
- symbolSize: 6,
- type: 'scatter',
- data: accarr
- }]
- };
- if (this._isMounted) {
- this.setState({ accuracyData: accOption }, () => {
- if (accarr.length === 0) {
- this.setState({
- accNodata: 'No data'
- });
- } else {
- this.setState({
- accNodata: ''
- });
- }
- });
- }
- }
-
clickMaxTop = (event: React.SyntheticEvent) => {
event.stopPropagation();
// #999: background of the active panel; #b3b3b3: default background
- this.setState(() => ({ isTop10: true, titleMaxbgcolor: '#999', titleMinbgcolor: '#b3b3b3' }));
- this.showTrials();
+ this.setState({ metricGraphMode: 'max' });
}
clickMinTop = (event: React.SyntheticEvent) => {
event.stopPropagation();
- this.setState(() => ({ isTop10: false, titleMaxbgcolor: '#b3b3b3', titleMinbgcolor: '#999' }));
- this.showTrials();
- }
-
- isOffInterval = () => {
- const { status } = this.state;
- const { interval } = this.props;
- if (status === 'DONE' || status === 'ERROR' || status === 'STOPPED' ||
- interval === 0
- ) {
- window.clearInterval(this.intervalID);
- window.clearInterval(this.intervalProfile);
- return;
- }
+ this.setState({ metricGraphMode: 'min' });
}
- componentWillReceiveProps(nextProps: OverviewProps) {
- const { interval, whichPageToFresh } = nextProps;
- window.clearInterval(this.intervalID);
- window.clearInterval(this.intervalProfile);
- if (whichPageToFresh.includes('/oview')) {
- this.showTrials();
- this.showSessionPro();
- }
- if (interval !== 0) {
- this.intervalID = window.setInterval(this.showTrials, interval * 1000);
- this.intervalProfile = window.setInterval(this.showSessionPro, interval * 1000);
- }
+ changeConcurrency = (val: number) => {
+ this.setState({ trialConcurrency: val });
}
- componentDidMount() {
- this._isMounted = true;
- const { interval } = this.props;
- this.showTrials();
- this.showSessionPro();
- if (interval !== 0) {
- this.intervalID = window.setInterval(this.showTrials, interval * 1000);
- this.intervalProfile = window.setInterval(this.showSessionPro, interval * 1000);
- }
- }
+ render() {
+ const { trialConcurrency, metricGraphMode } = this.state;
+ const { experimentUpdateBroadcast } = this.props;
- componentWillUnmount() {
- this._isMounted = false;
- window.clearInterval(this.intervalID);
- window.clearInterval(this.intervalProfile);
- }
+ const searchSpace = this.convertSearchSpace();
- render() {
+ const bestTrials = this.findBestTrials();
+ const bestAccuracy = bestTrials.length > 0 ? bestTrials[0].accuracy! : NaN;
+ const accuracyGraphData = this.generateAccuracyGraph(bestTrials);
+ const noDataMessage = bestTrials.length > 0 ? '' : 'No data';
- const {
- trialProfile, searchSpace, tableData, accuracyData,
- accNodata, status, errorStr, trialNumber, bestAccuracy, isMultiPhase,
- titleMaxbgcolor, titleMinbgcolor, isLogCollection, experimentAPI
- } = this.state;
- const { concurrency } = this.props;
- trialProfile.runConcurren = concurrency;
- Object.keys(experimentAPI).map(item => {
- if (item === 'params') {
- const temp = experimentAPI[item];
- Object.keys(temp).map(index => {
- if (index === 'trialConcurrency') {
- temp[index] = concurrency;
- }
- });
- }
- });
+ const titleMaxbgcolor = (metricGraphMode === 'max' ? '#999' : '#b3b3b3');
+ const titleMinbgcolor = (metricGraphMode === 'min' ? '#999' : '#b3b3b3');
return (
{/* status and experiment block */}
-
+
{/* status graph */}
{/* experiment parameters search space tuner assessor... */}
@@ -512,7 +94,10 @@ class Overview extends React.Component {
{/* scroll bar for the whole trial profile in the searchSpace div */}
-
+
@@ -541,24 +126,79 @@ class Overview extends React.Component {
-
+ trial.info.id)}/>
);
}
+
+ private convertSearchSpace(): object {
+ const searchSpace = Object.assign({}, EXPERIMENT.searchSpace);
+ Object.keys(searchSpace).map(item => {
+ const key = searchSpace[item]._type;
+ let value = searchSpace[item]._value;
+ switch (key) {
+ case 'quniform':
+ case 'qnormal':
+ case 'qlognormal':
+ searchSpace[item]._value = [value[0], value[1]];
+ break;
+ default:
+ }
+ });
+ return searchSpace;
+ }
+
+ private findBestTrials(): Trial[] {
+ let bestTrials = TRIALS.sort();
+ if (this.state.metricGraphMode === 'max') {
+ bestTrials.reverse().splice(10);
+ } else {
+ bestTrials.splice(10);
+ }
+ return bestTrials;
+ }
+
+ private generateAccuracyGraph(bestTrials: Trial[]): object {
+ const xSequence = bestTrials.map(trial => trial.sequenceId);
+ const ySequence = bestTrials.map(trial => trial.accuracy);
+
+ return {
+ // left margin leaves room for y-axis labels as long as 0.0000000
+ grid: {
+ left: 67,
+ right: 40
+ },
+ tooltip: {
+ trigger: 'item'
+ },
+ xAxis: {
+ name: 'Trial',
+ type: 'category',
+ data: xSequence
+ },
+ yAxis: {
+ name: 'Default metric',
+ type: 'value',
+ scale: true,
+ data: ySequence
+ },
+ series: [{
+ symbolSize: 6,
+ type: 'scatter',
+ data: ySequence
+ }]
+ };
+ }
}
+
export default Overview;
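
`findBestTrials()` above relies on `TRIALS.sort()` and then reverses and truncates the result in place. The sketch below shows the same top-10 selection as a standalone function that does not mutate its input. It assumes that `TRIALS.sort()` returns finished trials in ascending order of the default metric, which the diff does not show; `TrialSummary` and `selectBestTrials` are hypothetical names used only for illustration.

```typescript
// Hypothetical, simplified stand-in for the Trial model referenced in the diff.
interface TrialSummary {
    id: string;
    sequenceId: number;
    accuracy?: number;  // undefined until a final metric has been reported
}

type MetricGraphMode = 'max' | 'min';

// Return up to `limit` best trials, best first, leaving `trials` untouched.
function selectBestTrials(
    trials: TrialSummary[],
    mode: MetricGraphMode,
    limit: number = 10
): TrialSummary[] {
    // Trials without a final metric cannot be ranked.
    const finished = trials.filter(trial => trial.accuracy !== undefined);
    // Sort a copy in ascending order of accuracy.
    const ascending = finished
        .slice()
        .sort((a, b) => (a.accuracy as number) - (b.accuracy as number));
    // 'max' wants the largest values first, 'min' the smallest.
    const ordered = mode === 'max' ? ascending.reverse() : ascending;
    return ordered.slice(0, limit);
}
```

With this shape, the first element of the result plays the role of `bestTrials[0]` in `render()`, and an empty result corresponds to the `'No data'` message.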
diff --git a/src/webui/src/components/SlideBar.tsx b/src/webui/src/components/SlideBar.tsx
index 5e2e0edb5f..1200fb8a7d 100644
--- a/src/webui/src/components/SlideBar.tsx
+++ b/src/webui/src/components/SlideBar.tsx
@@ -26,7 +26,6 @@ interface SliderState {
interface SliderProps extends FormComponentProps {
changeInterval: (value: number) => void;
- changeFresh: (value: string) => void;
}
interface EventPer {
@@ -35,7 +34,6 @@ interface EventPer {
class SlideBar extends React.Component {
- public _isMounted = false;
public divMenu: HTMLDivElement | null;
public selectHTML: Select | null;
@@ -57,32 +55,26 @@ class SlideBar extends React.Component {
method: 'GET'
})
.then(res => {
- if (res.status === 200 && this._isMounted) {
+ if (res.status === 200) {
this.setState({ version: res.data });
}
});
}
handleMenuClick = (e: EventPer) => {
- if (this._isMounted) { this.setState({ menuVisible: false }); }
+ this.setState({ menuVisible: false });
switch (e.key) {
// to see & download experiment parameters
case '1':
- if (this._isMounted === true) {
- this.setState(() => ({ isvisibleExperimentDrawer: true }));
- }
+ this.setState({ isvisibleExperimentDrawer: true });
break;
// to see & download nnimanager log
case '2':
- if (this._isMounted === true) {
- this.setState(() => ({ activeKey: 'nnimanager', isvisibleLogDrawer: true }));
- }
+ this.setState({ activeKey: 'nnimanager', isvisibleLogDrawer: true });
break;
// to see & download dispatcher log
case '3':
- if (this._isMounted === true) {
- this.setState(() => ({ isvisibleLogDrawer: true, activeKey: 'dispatcher' }));
- }
+ this.setState({ isvisibleLogDrawer: true, activeKey: 'dispatcher' });
break;
case 'close':
case '10':
@@ -96,13 +88,10 @@ class SlideBar extends React.Component {
}
handleVisibleChange = (flag: boolean) => {
- if (this._isMounted === true) {
- this.setState({ menuVisible: flag });
- }
+ this.setState({ menuVisible: flag });
}
getInterval = (value: string) => {
-
if (value === 'close') {
this.props.changeInterval(0);
} else {
@@ -203,13 +192,9 @@ class SlideBar extends React.Component {
fresh = (event: React.SyntheticEvent) => {
event.preventDefault();
event.stopPropagation();
- if (this._isMounted) {
- this.setState({ isdisabledFresh: true }, () => {
- const whichPage = window.location.pathname;
- this.props.changeFresh(whichPage);
- setTimeout(() => { this.setState(() => ({ isdisabledFresh: false })); }, 1000);
- });
- }
+ this.setState({ isdisabledFresh: true }, () => {
+ setTimeout(() => { this.setState({ isdisabledFresh: false }); }, 1000);
+ });
}
desktopHTML = () => {
@@ -223,6 +208,16 @@ class SlideBar extends React.Component {
{DETAILTABS}
+
+
+
+ Help
+
+
+
{this.select()}
{
View
{
menuVisible
- ?
-
- :
-
+ ?
+
+ :
+
}
@@ -320,27 +315,18 @@ class SlideBar extends React.Component {
}
// close log drawer (nnimanager, dispatcher)
closeLogDrawer = () => {
- if (this._isMounted === true) {
- this.setState(() => ({ isvisibleLogDrawer: false, activeKey: '' }));
- }
+ this.setState({ isvisibleLogDrawer: false, activeKey: '' });
}
// close download experiment parameters drawer
closeExpDrawer = () => {
- if (this._isMounted === true) {
- this.setState(() => ({ isvisibleExperimentDrawer: false }));
- }
+ this.setState({ isvisibleExperimentDrawer: false });
}
componentDidMount() {
- this._isMounted = true;
this.getNNIversion();
}
- componentWillUnmount() {
- this._isMounted = false;
- }
-
render() {
const mobile = ({this.mobileHTML()} );
const tablet = ({this.tabeltHTML()} );
@@ -366,4 +352,4 @@ class SlideBar extends React.Component {
}
}
-export default Form.create()(SlideBar);
\ No newline at end of file
+export default Form.create()(SlideBar);
diff --git a/src/webui/src/components/TrialsDetail.tsx b/src/webui/src/components/TrialsDetail.tsx
index c5ee9024ab..b6f1897f29 100644
--- a/src/webui/src/components/TrialsDetail.tsx
+++ b/src/webui/src/components/TrialsDetail.tsx
@@ -1,10 +1,8 @@
import * as React from 'react';
-import axios from 'axios';
-import { MANAGER_IP } from '../static/const';
import { Row, Col, Tabs, Select, Button, Icon } from 'antd';
const Option = Select.Option;
-import { TableObj, Parameters, ExperimentInfo } from '../static/interface';
-import { getFinal } from '../static/function';
+import { EXPERIMENT, TRIALS } from '../static/datamodel';
+import { Trial } from '../static/model/trial';
import DefaultPoint from './trial-detail/DefaultMetricPoint';
import Duration from './trial-detail/Duration';
import Title1 from './overview/Title1';
@@ -16,37 +14,22 @@ import '../static/style/trialsDetail.scss';
import '../static/style/search.scss';
interface TrialDetailState {
- accSource: object;
- accNodata: string;
- tableListSource: Array;
- searchResultSource: Array;
- isHasSearch: boolean;
- experimentLogCollection: boolean;
- entriesTable: number; // table components val
- entriesInSelect: string;
- searchSpace: string;
- isMultiPhase: boolean;
+ tablePageSize: number; // rows per table page; -1 means show all
whichGraph: string;
- hyperCounts: number; // user click the hyper-parameter counts
- durationCounts: number;
- intermediateCounts: number;
- experimentInfo: ExperimentInfo;
- searchFilter: string;
- searchPlaceHolder: string;
+ searchType: string;
+ searchFilter: (trial: Trial) => boolean;
}
interface TrialsDetailProps {
- interval: number;
- whichPageToFresh: string;
columnList: Array;
changeColumn: (val: Array) => void;
+ experimentUpdateBroadcast: number;
+ trialsUpdateBroadcast: number;
}
class TrialsDetail extends React.Component {
- public _isMounted = false;
public interAccuracy = 0;
- public interTableList = 1;
public interAllTableList = 2;
public tableList: TableList | null;
@@ -73,333 +56,67 @@ class TrialsDetail extends React.Component
constructor(props: TrialsDetailProps) {
super(props);
-
this.state = {
- accSource: {},
- accNodata: '',
- tableListSource: [],
- searchResultSource: [],
- experimentLogCollection: false,
- entriesTable: 20,
- entriesInSelect: '20',
- searchSpace: '',
+ tablePageSize: 20,
whichGraph: '1',
- isHasSearch: false,
- isMultiPhase: false,
- hyperCounts: 0,
- durationCounts: 0,
- intermediateCounts: 0,
- experimentInfo: {
- platform: '',
- optimizeMode: 'maximize'
- },
- searchFilter: 'id',
- searchPlaceHolder: 'Search by id'
+ searchType: 'id',
+ searchFilter: trial => true,
};
}
- getDetailSource = () => {
- this.isOffIntervals();
- axios
- .all([
- axios.get(`${MANAGER_IP}/trial-jobs`),
- axios.get(`${MANAGER_IP}/metric-data`)
- ])
- .then(axios.spread((res, res1) => {
- if (res.status === 200 && res1.status === 200) {
- const trialJobs = res.data;
- const metricSource = res1.data;
- const trialTable: Array = [];
- Object.keys(trialJobs).map(item => {
- let desc: Parameters = {
- parameters: {},
- intermediate: [],
- multiProgress: 1
- };
- let duration = 0;
- const id = trialJobs[item].id !== undefined
- ? trialJobs[item].id
- : '';
- const status = trialJobs[item].status !== undefined
- ? trialJobs[item].status
- : '';
- const begin = trialJobs[item].startTime;
- const end = trialJobs[item].endTime;
- if (begin) {
- if (end) {
- duration = (end - begin) / 1000;
- } else {
- duration = (new Date().getTime() - begin) / 1000;
- }
- }
- const tempHyper = trialJobs[item].hyperParameters;
- if (tempHyper !== undefined) {
- const getPara = JSON.parse(tempHyper[tempHyper.length - 1]).parameters;
- desc.multiProgress = tempHyper.length;
- if (typeof getPara === 'string') {
- desc.parameters = JSON.parse(getPara);
- } else {
- desc.parameters = getPara;
- }
- } else {
- desc.parameters = { error: 'This trial\'s parameters are not available.' };
- }
- if (trialJobs[item].logPath !== undefined) {
- desc.logPath = trialJobs[item].logPath;
- }
-
- const acc = getFinal(trialJobs[item].finalMetricData);
- // deal with intermediate result list
- const mediate: Array = [];
- Object.keys(metricSource).map(key => {
- const items = metricSource[key];
- if (items.trialJobId === id) {
- // succeed trial, last intermediate result is final result
- // final result format may be object
- if (typeof JSON.parse(items.data) === 'object') {
- mediate.push(JSON.parse(items.data).default);
- } else {
- mediate.push(JSON.parse(items.data));
- }
- }
- });
- desc.intermediate = mediate;
- trialTable.push({
- key: trialTable.length,
- sequenceId: trialJobs[item].sequenceId,
- id: id,
- status: status,
- duration: duration,
- acc: acc,
- description: desc
- });
- });
- // update search data result
- const { searchResultSource, entriesInSelect } = this.state;
- if (searchResultSource.length !== 0) {
- const temp: Array = [];
- Object.keys(searchResultSource).map(index => {
- temp.push(searchResultSource[index].id);
- });
- const searchResultList: Array = [];
- for (let i = 0; i < temp.length; i++) {
- Object.keys(trialTable).map(key => {
- const item = trialTable[key];
- if (item.id === temp[i]) {
- searchResultList.push(item);
- }
- });
- }
-
- if (this._isMounted) {
- this.setState(() => ({
- searchResultSource: searchResultList
- }));
- }
- }
- if (this._isMounted) {
- this.setState(() => ({ tableListSource: trialTable }));
- }
- if (entriesInSelect === 'all' && this._isMounted) {
- this.setState(() => ({
- entriesTable: trialTable.length
- }));
- }
- }
- }));
- }
-
// search trials by id, trial No., status, or parameters
searchTrial = (event: React.ChangeEvent) => {
const targetValue = event.target.value;
- if (targetValue === '' || targetValue === ' ') {
- const { tableListSource } = this.state;
- if (this._isMounted) {
- this.setState(() => ({
- isHasSearch: false,
- tableListSource: tableListSource,
- }));
- }
- } else {
- const { tableListSource, searchFilter } = this.state;
- const searchResultList: Array = [];
- Object.keys(tableListSource).map(key => {
- const item = tableListSource[key];
- switch (searchFilter) {
- case 'id':
- if (item.id.toUpperCase().includes(targetValue.toUpperCase())) {
- searchResultList.push(item);
- }
- break;
- case 'Trial No.':
- if (item.sequenceId.toString() === targetValue) {
- searchResultList.push(item);
- }
- break;
- case 'status':
- if (item.status.toUpperCase().includes(targetValue.toUpperCase())) {
- searchResultList.push(item);
- }
- break;
- case 'parameters':
- const strParameters = JSON.stringify(item.description.parameters, null, 4);
- if (strParameters.includes(targetValue)) {
- searchResultList.push(item);
- }
- break;
- default:
- }
- });
- if (this._isMounted) {
- this.setState(() => ({
- searchResultSource: searchResultList,
- isHasSearch: true
- }));
- }
- }
- }
-
- // close timer
- isOffIntervals = () => {
- const { interval } = this.props;
- if (interval === 0) {
- window.clearInterval(this.interTableList);
+ let filter = (trial: Trial) => true;
+ if (!targetValue.trim()) {
+ this.setState({ searchFilter: filter });
return;
- } else {
- axios(`${MANAGER_IP}/check-status`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200 && this._isMounted) {
- const expStatus = res.data.status;
- if (expStatus === 'DONE' || expStatus === 'ERROR' || expStatus === 'STOPPED') {
- window.clearInterval(this.interTableList);
- return;
- }
- }
- });
}
+ switch (this.state.searchType) {
+ case 'id':
+ filter = trial => trial.info.id.toUpperCase().includes(targetValue.toUpperCase());
+ break;
+ case 'Trial No.':
+ filter = trial => trial.info.sequenceId.toString() === targetValue;
+ break;
+ case 'status':
+ filter = trial => trial.info.status.toUpperCase().includes(targetValue.toUpperCase());
+ break;
+ case 'parameters':
+ // TODO: support filters like `x: 2` (instead of `"x": 2`)
+ filter = trial => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue);
+ break;
+ default:
+ alert(`Unexpected search filter ${this.state.searchType}`);
+ }
+ this.setState({ searchFilter: filter });
}
- handleEntriesSelect = (value: string) => {
- // user select isn't 'all'
- if (value !== 'all') {
- if (this._isMounted) {
- this.setState(() => ({ entriesTable: parseInt(value, 10) }));
- }
- } else {
- const { tableListSource } = this.state;
- if (this._isMounted) {
- this.setState(() => ({
- entriesInSelect: 'all',
- entriesTable: tableListSource.length
- }));
- }
- }
+ handleTablePageSizeSelect = (value: string) => {
+ this.setState({ tablePageSize: value === 'all' ? -1 : parseInt(value, 10) });
}
handleWhichTabs = (activeKey: string) => {
- // const which = JSON.parse(activeKey);
- if (this._isMounted) {
- this.setState(() => ({ whichGraph: activeKey }));
- }
+ this.setState({ whichGraph: activeKey });
}
test = () => {
alert('TableList component was not properly initialized.');
}
- getSearchFilter = (value: string) => {
+ updateSearchFilterType = (value: string) => {
// clear input value and re-render table
if (this.searchInput !== null) {
this.searchInput.value = '';
- if (this._isMounted === true) {
- this.setState(() => ({ isHasSearch: false }));
- }
- }
- if (this._isMounted === true) {
- this.setState(() => ({ searchFilter: value, searchPlaceHolder: `Search by ${value}` }));
}
- }
-
- // get and set logCollection val
- checkExperimentPlatform = () => {
- axios(`${MANAGER_IP}/experiment`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- const trainingPlatform: string = res.data.params.trainingServicePlatform !== undefined
- ?
- res.data.params.trainingServicePlatform
- :
- '';
- // default logCollection is true
- const logCollection = res.data.params.logCollection;
- let expLogCollection: boolean = false;
- const isMultiy: boolean = res.data.params.multiPhase !== undefined
- ? res.data.params.multiPhase : false;
- const tuner = res.data.params.tuner;
- // I'll set optimize is maximize if user not set optimize
- let optimize: string = 'maximize';
- if (tuner !== undefined) {
- if (tuner.classArgs !== undefined) {
- if (tuner.classArgs.optimize_mode !== undefined) {
- if (tuner.classArgs.optimize_mode === 'minimize') {
- optimize = 'minimize';
- }
- }
- }
- }
- if (logCollection !== undefined && logCollection !== 'none') {
- expLogCollection = true;
- }
- if (this._isMounted) {
- this.setState({
- experimentInfo: { platform: trainingPlatform, optimizeMode: optimize },
- searchSpace: res.data.params.searchSpace,
- experimentLogCollection: expLogCollection,
- isMultiPhase: isMultiy
- });
- }
- }
- });
- }
-
- componentWillReceiveProps(nextProps: TrialsDetailProps) {
- const { interval, whichPageToFresh } = nextProps;
- window.clearInterval(this.interTableList);
- if (interval !== 0) {
- this.interTableList = window.setInterval(this.getDetailSource, interval * 1000);
- }
- if (whichPageToFresh.includes('/detail')) {
- this.getDetailSource();
- }
- }
-
- componentDidMount() {
-
- this._isMounted = true;
- const { interval } = this.props;
- this.getDetailSource();
- this.interTableList = window.setInterval(this.getDetailSource, interval * 1000);
- this.checkExperimentPlatform();
- }
-
- componentWillUnmount() {
- this._isMounted = false;
- window.clearInterval(this.interTableList);
+ this.setState({ searchType: value });
}
render() {
-
- const {
- tableListSource, searchResultSource, isHasSearch, isMultiPhase,
- entriesTable, experimentInfo, searchSpace, experimentLogCollection,
- whichGraph, searchPlaceHolder
- } = this.state;
- const source = isHasSearch ? searchResultSource : tableListSource;
+ const { tablePageSize, whichGraph } = this.state;
const { columnList, changeColumn } = this.props;
+ const source = TRIALS.filter(this.state.searchFilter);
+ const trialIds = source.map(trial => trial.id);
return (
@@ -407,10 +124,9 @@ class TrialsDetail extends React.Component
@@ -418,7 +134,7 @@ class TrialsDetail extends React.Component
@@ -438,7 +154,7 @@ class TrialsDetail extends React.Component
Show
20
@@ -462,7 +178,7 @@ class TrialsDetail extends React.Component
>
Compare
-
+
Id
Trial No.
Status
@@ -471,7 +187,7 @@ class TrialsDetail extends React.Component
(this.searchInput) = text}
@@ -479,14 +195,11 @@ class TrialsDetail extends React.Component
trial.tableRecord)}
columnList={columnList}
changeColumn={changeColumn}
+ trialsUpdateBroadcast={this.props.trialsUpdateBroadcast}
ref={(tabList) => this.tableList = tabList}
/>
@@ -494,4 +207,4 @@ class TrialsDetail extends React.Component
}
}
-export default TrialsDetail;
\ No newline at end of file
+export default TrialsDetail;
diff --git a/src/webui/src/components/overview/BasicInfo.tsx b/src/webui/src/components/overview/BasicInfo.tsx
index dfddde7a1e..b47fca53f0 100644
--- a/src/webui/src/components/overview/BasicInfo.tsx
+++ b/src/webui/src/components/overview/BasicInfo.tsx
@@ -1,68 +1,45 @@
+import { Col, Row, Tooltip } from 'antd';
import * as React from 'react';
-import {
- Row, Col,
- Tooltip
-} from 'antd';
-import { Experiment } from '../../static/interface';
+import { EXPERIMENT } from '../../static/datamodel';
+import { formatTimestamp } from '../../static/function';
interface BasicInfoProps {
- trialProfile: Experiment;
- status: string;
+ experimentUpdateBroadcast: number;
}
class BasicInfo extends React.Component {
-
constructor(props: BasicInfoProps) {
super(props);
}
render() {
- const { trialProfile } = this.props;
return (
Name
- {trialProfile.experName}
+ {EXPERIMENT.profile.params.experimentName}
ID
- {trialProfile.id}
+ {EXPERIMENT.profile.id}
Start time
-
- {new Date(trialProfile.startTime).toLocaleString('en-US')}
-
+ {formatTimestamp(EXPERIMENT.profile.startTime)}
End time
-
- {
- trialProfile.endTime
- ?
- new Date(trialProfile.endTime).toLocaleString('en-US')
- :
- 'none'
- }
-
+ {formatTimestamp(EXPERIMENT.profile.endTime)}
Log directory
-
- {trialProfile.logDir}
+
+ {EXPERIMENT.profile.logDir || 'unknown'}
Training platform
-
- {
- trialProfile.trainingServicePlatform
- ?
- trialProfile.trainingServicePlatform
- :
- 'none'
- }
-
+ {EXPERIMENT.profile.params.trainingServicePlatform}
);
}
}
-export default BasicInfo;
\ No newline at end of file
+export default BasicInfo;
diff --git a/src/webui/src/components/overview/NumInput.tsx b/src/webui/src/components/overview/NumInput.tsx
new file mode 100644
index 0000000000..0c014a3233
--- /dev/null
+++ b/src/webui/src/components/overview/NumInput.tsx
@@ -0,0 +1,85 @@
+import * as React from 'react';
+import { Button, Row } from 'antd';
+
+interface ConcurrencyInputProps {
+ value: number;
+ updateValue: (val: string) => void;
+}
+
+interface ConcurrencyInputStates {
+ editting: boolean;
+}
+
+class ConcurrencyInput extends React.Component<ConcurrencyInputProps, ConcurrencyInputStates> {
+ private input = React.createRef<HTMLInputElement>();
+
+ constructor(props: ConcurrencyInputProps) {
+ super(props);
+ this.state = { editting: false };
+ }
+
+ save = () => {
+ if (this.input.current !== null) {
+ this.props.updateValue(this.input.current.value);
+ this.setState({ editting: false });
+ }
+ }
+
+ cancel = () => {
+ this.setState({ editting: false });
+ }
+
+ edit = () => {
+ this.setState({ editting: true });
+ }
+
+ render() {
+ if (this.state.editting) {
+ return (
+
+
+
+ Save
+
+
+ Cancel
+
+
+ );
+ } else {
+ return (
+
+
+
+ Edit
+
+
+ );
+ }
+ }
+}
+
+export default ConcurrencyInput;
diff --git a/src/webui/src/components/overview/Progress.tsx b/src/webui/src/components/overview/Progress.tsx
index 39d6ee3322..398fbf2ff7 100644
--- a/src/webui/src/components/overview/Progress.tsx
+++ b/src/webui/src/components/overview/Progress.tsx
@@ -1,192 +1,99 @@
import * as React from 'react';
-import { Row, Col, Popover, Button, message } from 'antd';
+import { Row, Col, Popover, message } from 'antd';
import axios from 'axios';
-import { MANAGER_IP, CONTROLTYPE } from '../../static/const';
-import { Experiment, TrialNumber } from '../../static/interface';
+import { MANAGER_IP } from '../../static/const';
+import { EXPERIMENT, TRIALS } from '../../static/datamodel';
import { convertTime } from '../../static/function';
+import ConcurrencyInput from './NumInput';
import ProgressBar from './ProgressItem';
import LogDrawer from '../Modal/LogDrawer';
import '../../static/style/progress.scss';
import '../../static/style/probar.scss';
interface ProgressProps {
- trialProfile: Experiment;
concurrency: number;
- trialNumber: TrialNumber;
bestAccuracy: number;
- status: string;
- errors: string;
changeConcurrency: (val: number) => void;
+ experimentUpdateBroadcast: number;
}
interface ProgressState {
- btnName: string;
- isEnable: boolean;
- userInputVal: string; // get user input
- cancelSty: string;
isShowLogDrawer: boolean;
}
class Progressed extends React.Component<ProgressProps, ProgressState> {
-
- public conInput: HTMLInputElement | null;
- public _isMounted = false;
constructor(props: ProgressProps) {
super(props);
this.state = {
- btnName: 'Edit',
- isEnable: true,
- userInputVal: this.props.trialProfile.runConcurren.toString(),
- cancelSty: 'none',
isShowLogDrawer: false
};
}
- editTrialConcurrency = () => {
- const { btnName } = this.state;
- if (this._isMounted) {
- if (btnName === 'Edit') {
- // user click edit
- this.setState(() => ({
- isEnable: false,
- btnName: 'Save',
- cancelSty: 'inline-block'
- }));
- } else {
- // user click save button
- axios(`${MANAGER_IP}/experiment`, {
- method: 'GET'
- })
- .then(rese => {
- if (rese.status === 200) {
- const { userInputVal } = this.state;
- const experimentFile = rese.data;
- const trialConcurrency = experimentFile.params.trialConcurrency;
- if (userInputVal !== undefined) {
- if (userInputVal === trialConcurrency.toString() || userInputVal === '0') {
- message.destroy();
- message.info(
- `trialConcurrency's value is ${trialConcurrency}, you did not modify it`, 2);
- } else {
- experimentFile.params.trialConcurrency = parseInt(userInputVal, 10);
- // rest api, modify trial concurrency value
- axios(`${MANAGER_IP}/experiment`, {
- method: 'PUT',
- headers: {
- 'Content-Type': 'application/json;charset=utf-8'
- },
- data: experimentFile,
- params: {
- update_type: CONTROLTYPE[1]
- }
- }).then(res => {
- if (res.status === 200) {
- message.destroy();
- message.success(`Update ${CONTROLTYPE[1].toLocaleLowerCase()}
- successfully`);
- this.props.changeConcurrency(parseInt(userInputVal, 10));
- }
- })
- .catch(error => {
- if (error.response.status === 500) {
- if (error.response.data.error) {
- message.error(error.response.data.error);
- } else {
- message.error(
- `Update ${CONTROLTYPE[1].toLocaleLowerCase()} failed`);
- }
- }
- });
- // btn -> edit
- this.setState(() => ({
- btnName: 'Edit',
- isEnable: true,
- cancelSty: 'none'
- }));
- }
- }
- }
- });
- }
- }
- }
-
- cancelFunction = () => {
- const { trialProfile } = this.props;
- if (this._isMounted) {
- this.setState(
- () => ({
- btnName: 'Edit',
- isEnable: true,
- cancelSty: 'none',
- }));
+ editTrialConcurrency = async (userInput: string) => {
+ if (!userInput.match(/^[1-9]\d*$/)) {
+ message.error('Please enter a positive integer!', 2);
+ return;
}
- if (this.conInput !== null) {
- this.conInput.value = trialProfile.runConcurren.toString();
+ const newConcurrency = parseInt(userInput, 10);
+ if (newConcurrency === this.props.concurrency) {
+ message.info(`Trial concurrency has not changed`, 2);
+ return;
}
- }
- getUserTrialConcurrency = (event: React.ChangeEvent) => {
- const value = event.target.value;
- if (value.match(/^[1-9]\d*$/) || value === '') {
- this.setState(() => ({
- userInputVal: value
- }));
- } else {
- message.error('Please enter a positive integer!', 2);
- if (this.conInput !== null) {
- const { trialProfile } = this.props;
- this.conInput.value = trialProfile.runConcurren.toString();
+ const newProfile = Object.assign({}, EXPERIMENT.profile);
+ newProfile.params = Object.assign({}, EXPERIMENT.profile.params, { trialConcurrency: newConcurrency });
+
+ // rest api, modify trial concurrency value
+ try {
+ const res = await axios.put(`${MANAGER_IP}/experiment`, newProfile, {
+ params: { update_type: 'TRIAL_CONCURRENCY' }
+ });
+ if (res.status === 200) {
+ message.success(`Successfully updated trial concurrency`);
+ // NOTE: should we do this earlier in favor of poor networks?
+ this.props.changeConcurrency(newConcurrency);
+ }
+ } catch (error) {
+ if (error.response && error.response.data.error) {
+ message.error(`Failed to update trial concurrency\n${error.response.data.error}`);
+ } else if (error.response) {
+ message.error(`Failed to update trial concurrency\nServer responded ${error.response.status}`);
+ } else if (error.message) {
+ message.error(`Failed to update trial concurrency\n${error.message}`);
+ } else {
+ message.error(`Failed to update trial concurrency\nUnknown error`);
}
}
}
isShowDrawer = () => {
- if (this._isMounted === true) {
- this.setState(() => ({ isShowLogDrawer: true }));
- }
+ this.setState({ isShowLogDrawer: true });
}
closeDrawer = () => {
- if (this._isMounted === true) {
- this.setState(() => ({ isShowLogDrawer: false }));
- }
+ this.setState({ isShowLogDrawer: false });
}
- componentWillReceiveProps() {
- const { trialProfile } = this.props;
- if (this.conInput !== null) {
- this.conInput.value = trialProfile.runConcurren.toString();
- }
- }
+ render() {
+ const { bestAccuracy } = this.props;
+ const { isShowLogDrawer } = this.state;
- componentDidMount() {
- this._isMounted = true;
- }
+ const count = TRIALS.countStatus();
+ const stoppedCount = count.get('USER_CANCELED')! + count.get('SYS_CANCELED')! + count.get('EARLY_STOPPED')!;
+ const bar2 = count.get('RUNNING')! + count.get('SUCCEEDED')! + count.get('FAILED')! + stoppedCount;
- componentWillUnmount() {
- this._isMounted = false;
- }
+ const bar2Percent = (bar2 / EXPERIMENT.profile.params.maxTrialNum) * 100;
+ const percent = (EXPERIMENT.profile.execDuration / EXPERIMENT.profile.params.maxExecDuration) * 100;
+ const remaining = convertTime(EXPERIMENT.profile.params.maxExecDuration - EXPERIMENT.profile.execDuration);
+ const maxDuration = convertTime(EXPERIMENT.profile.params.maxExecDuration);
+ const maxTrialNum = EXPERIMENT.profile.params.maxTrialNum;
+ const execDuration = convertTime(EXPERIMENT.profile.execDuration);
- render() {
- const { trialProfile, trialNumber, bestAccuracy, status, errors } = this.props;
- const { isEnable, btnName, cancelSty, isShowLogDrawer } = this.state;
- const bar2 = trialNumber.totalCurrentTrial - trialNumber.waitTrial - trialNumber.unknowTrial;
- const bar2Percent = (bar2 / trialProfile.MaxTrialNum) * 100;
- const percent = (trialProfile.execDuration / trialProfile.maxDuration) * 100;
- const runDuration = convertTime(trialProfile.execDuration);
- const temp = trialProfile.maxDuration - trialProfile.execDuration;
- let remaining;
let errorContent;
- if (temp < 0) {
- remaining = '0';
- } else {
- remaining = convertTime(temp);
- }
- if (errors !== '') {
+ if (EXPERIMENT.error) {
errorContent = (
- {errors}
+ {EXPERIMENT.error}
);
@@ -196,9 +103,9 @@ class Progressed extends React.Component {
Status
-
{status}
+
{EXPERIMENT.status}
{
- status === 'ERROR'
+ EXPERIMENT.status === 'ERROR'
?
{
Best metric
- {bestAccuracy.toFixed(6)}
+ {isNaN(bestAccuracy) ? 'N/A' : bestAccuracy.toFixed(6)}
Spent
- {convertTime(trialProfile.execDuration)}
+ {execDuration}
@@ -247,54 +154,32 @@ class Progressed extends React.Component {
{/* modify concurrency */}
Concurrency
-
- this.conInput = input}
- />
- {btnName}
-
-
- Cancel
-
-
+
Running
- {trialNumber.runTrial}
+ {count.get('RUNNING')}
Succeeded
- {trialNumber.succTrial}
+ {count.get('SUCCEEDED')}
Stopped
- {trialNumber.stopTrial}
+ {stoppedCount}
Failed
- {trialNumber.failTrial}
+ {count.get('FAILED')}
@@ -309,4 +194,4 @@ class Progressed extends React.Component {
}
}
-export default Progressed;
\ No newline at end of file
+export default Progressed;
diff --git a/src/webui/src/components/overview/SuccessTable.tsx b/src/webui/src/components/overview/SuccessTable.tsx
index 18d7ee55a6..1e99d2416c 100644
--- a/src/webui/src/components/overview/SuccessTable.tsx
+++ b/src/webui/src/components/overview/SuccessTable.tsx
@@ -2,131 +2,83 @@ import * as React from 'react';
import { Table } from 'antd';
import OpenRow from '../public-child/OpenRow';
import DefaultMetric from '../public-child/DefaultMetrc';
-import { TableObj } from '../../static/interface';
+import { TRIALS } from '../../static/datamodel';
+import { TableRecord } from '../../static/interface';
import { convertDuration } from '../../static/function';
import '../../static/style/tableStatus.css';
import '../../static/style/openRow.scss';
interface SuccessTableProps {
- tableSource: Array;
- trainingPlatform: string;
- logCollection: boolean;
- multiphase: boolean;
+ trialIds: string[];
}
-class SuccessTable extends React.Component {
-
- public _isMounted = false;
+function openRow(record: TableRecord) {
+ return (
+
+ );
+}
+class SuccessTable extends React.Component {
constructor(props: SuccessTableProps) {
super(props);
-
- }
-
- openRow = (record: TableObj) => {
- const { trainingPlatform, logCollection, multiphase } = this.props;
- return (
-
- );
- }
-
- componentDidMount() {
- this._isMounted = true;
- }
-
- componentWillUnmount() {
- this._isMounted = false;
}
render() {
- const { tableSource } = this.props;
-
- let bgColor = '';
- const columns = [{
- title: 'Trial No.',
- dataIndex: 'sequenceId',
- key: 'sequenceId',
- width: 140,
- className: 'tableHead'
- }, {
- title: 'ID',
- dataIndex: 'id',
- key: 'id',
- width: 60,
- className: 'tableHead leftTitle',
- render: (text: string, record: TableObj) => {
- return (
- {record.id}
- );
- },
- }, {
- title: 'Duration',
- dataIndex: 'duration',
- key: 'duration',
- width: 140,
- sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
- render: (text: string, record: TableObj) => {
- let duration;
- if (record.duration !== undefined) {
- // duration is nagative number(-1) & 0-1
- if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
- duration = `${record.duration}s`;
- } else {
- duration = convertDuration(record.duration);
- }
- } else {
- duration = 0;
+ const columns = [
+ {
+ title: 'Trial No.',
+ dataIndex: 'sequenceId',
+ width: 140,
+ className: 'tableHead'
+ }, {
+ title: 'ID',
+ dataIndex: 'id',
+ width: 60,
+ className: 'tableHead leftTitle',
+ render: (text: string, record: TableRecord) => {
+ return (
+ {record.id}
+ );
+ },
+ }, {
+ title: 'Duration',
+ dataIndex: 'duration',
+ width: 140,
+ render: (text: string, record: TableRecord) => {
+ return (
+ {convertDuration(record.duration)}
+ );
+ },
+ }, {
+ title: 'Status',
+ dataIndex: 'status',
+ width: 150,
+ className: 'tableStatus',
+ render: (text: string, record: TableRecord) => {
+ return (
+ {record.status}
+ );
}
- return (
-
- );
- },
- }, {
- title: 'Status',
- dataIndex: 'status',
- key: 'status',
- width: 150,
- className: 'tableStatus',
- render: (text: string, record: TableObj) => {
- bgColor = record.status;
- return (
-
- {record.status}
-
- );
- }
- }, {
- title: 'Default metric',
- dataIndex: 'acc',
- key: 'acc',
- sorter: (a: TableObj, b: TableObj) => {
- if (a.acc !== undefined && b.acc !== undefined) {
- return JSON.parse(a.acc.default) - JSON.parse(b.acc.default);
- } else {
- return NaN;
+ }, {
+ title: 'Default metric',
+ dataIndex: 'accuracy',
+ render: (text: string, record: TableRecord) => {
+ return (
+
+ );
}
- },
- render: (text: string, record: TableObj) => {
- return (
-
- );
}
- }];
+ ];
return (
+
);
}
}
diff --git a/src/webui/src/components/overview/TrialProfile.tsx b/src/webui/src/components/overview/TrialProfile.tsx
index 4820fa7ccd..dd55dd0868 100644
--- a/src/webui/src/components/overview/TrialProfile.tsx
+++ b/src/webui/src/components/overview/TrialProfile.tsx
@@ -1,9 +1,11 @@
import * as React from 'react';
import MonacoEditor from 'react-monaco-editor';
import { MONACO } from '../../static/const';
+import { EXPERIMENT } from '../../static/datamodel';
interface TrialInfoProps {
- experiment: object;
+ experimentUpdateBroadcast: number;
+ concurrency: number;
}
class TrialInfo extends React.Component {
@@ -12,32 +14,21 @@ class TrialInfo extends React.Component {
super(props);
}
- componentWillReceiveProps(nextProps: TrialInfoProps) {
- const experiments = nextProps.experiment;
- Object.keys(experiments).map(key => {
- switch (key) {
- case 'id':
- case 'logDir':
- case 'startTime':
- case 'endTime':
- experiments[key] = undefined;
- break;
- case 'params':
- const params = experiments[key];
- Object.keys(params).map(item => {
- if (item === 'experimentName' || item === 'searchSpace'
- || item === 'trainingServicePlatform') {
- params[item] = undefined;
- }
- });
- break;
- default:
+ render() {
+ const blacklist = [
+ 'id', 'logDir', 'startTime', 'endTime',
+ 'experimentName', 'searchSpace', 'trainingServicePlatform'
+ ];
+ // tslint:disable-next-line:no-any
+ const filter = (key: string, val: any) => {
+ if (key === 'trialConcurrency') {
+ return this.props.concurrency;
}
- });
- }
+ return blacklist.includes(key) ? undefined : val;
+ };
+ const profile = JSON.stringify(EXPERIMENT.profile, filter, 2);
- render() {
- const { experiment } = this.props;
+ // FIXME: highlight not working?
return (
{
height="361"
language="json"
theme="vs-light"
- value={JSON.stringify(experiment, null, 2)}
+ value={profile}
options={MONACO}
/>
diff --git a/src/webui/src/components/public-child/DefaultMetrc.tsx b/src/webui/src/components/public-child/DefaultMetrc.tsx
index d31b288c63..02a8894419 100644
--- a/src/webui/src/components/public-child/DefaultMetrc.tsx
+++ b/src/webui/src/components/public-child/DefaultMetrc.tsx
@@ -1,45 +1,22 @@
import * as React from 'react';
-import { TableObj } from '../../static/interface';
+import { TRIALS } from '../../static/datamodel';
+import { formatAccuracy } from '../../static/function';
interface DefaultMetricProps {
- record: TableObj;
+ trialId: string;
}
class DefaultMetric extends React.Component {
-
constructor(props: DefaultMetricProps) {
super(props);
-
}
render() {
- const { record } = this.props;
- let accuracy;
- if (record.acc !== undefined) {
- accuracy = record.acc.default;
- }
- let wei = 0;
- if (accuracy !== undefined) {
- if (accuracy.toString().indexOf('.') !== -1) {
- wei = accuracy.toString().length - accuracy.toString().indexOf('.') - 1;
- }
- }
+ const accuracy = TRIALS.getTrial(this.props.trialId).accuracy;
return (
-
- {
- record.acc !== undefined && record.acc.default !== undefined
- ?
- wei > 6
- ?
- JSON.parse(record.acc.default).toFixed(6)
- :
- record.acc.default
- :
- '--'
- }
-
+ {accuracy !== undefined ? formatAccuracy(accuracy) : '--'}
);
}
}
-export default DefaultMetric;
\ No newline at end of file
+export default DefaultMetric;
diff --git a/src/webui/src/components/public-child/IntermediateVal.tsx b/src/webui/src/components/public-child/IntermediateVal.tsx
index b5bd015843..2382ed4fbc 100644
--- a/src/webui/src/components/public-child/IntermediateVal.tsx
+++ b/src/webui/src/components/public-child/IntermediateVal.tsx
@@ -1,46 +1,18 @@
import * as React from 'react';
-import { TableObj } from '../../static/interface';
+import { TRIALS } from '../../static/datamodel';
interface IntermediateValProps {
- record: TableObj;
+ trialId: string;
}
class IntermediateVal extends React.Component {
-
constructor(props: IntermediateValProps) {
super(props);
-
}
render() {
- const { record } = this.props;
- const interArr = record.description.intermediate;
- let lastVal;
- let wei = 0;
- if (interArr !== undefined) {
- lastVal = interArr[interArr.length - 1];
- }
- let result: string = JSON.stringify(lastVal);
- if (lastVal !== undefined) {
- if (lastVal.toString().indexOf('.') !== -1) {
- wei = lastVal.toString().length - lastVal.toString().indexOf('.') - 1;
- if (wei > 6) {
- result = `${lastVal.toFixed(6)}`;
- }
- }
- // some trial haven't final result
- if (record.acc !== undefined) {
- if (record.acc.default !== undefined) {
- result = `${result} (FINAL)`;
- }
- } else {
- result = `${result} (LATEST)`;
- }
- } else {
- result = '--';
- }
return (
- {result}
+ {TRIALS.getTrial(this.props.trialId).formatLatestAccuracy()}
);
}
}
diff --git a/src/webui/src/components/public-child/OpenRow.tsx b/src/webui/src/components/public-child/OpenRow.tsx
index dcaf22a6d8..c990b03b94 100644
--- a/src/webui/src/components/public-child/OpenRow.tsx
+++ b/src/webui/src/components/public-child/OpenRow.tsx
@@ -2,7 +2,8 @@ import * as React from 'react';
import * as copy from 'copy-to-clipboard';
import PaiTrialLog from '../public-child/PaiTrialLog';
import TrialLog from '../public-child/TrialLog';
-import { TableObj } from '../../static/interface';
+import { EXPERIMENT, TRIALS } from '../../static/datamodel';
+import { Trial } from '../../static/model/trial';
import { Row, Tabs, Button, message, Modal } from 'antd';
import { MANAGER_IP } from '../../static/const';
import '../../static/style/overview.scss';
@@ -11,10 +12,7 @@ import JSONTree from 'react-json-tree';
const TabPane = Tabs.TabPane;
interface OpenRowProps {
- trainingPlatform: string;
- record: TableObj;
- logCollection: boolean;
- multiphase: boolean;
+ trialId: string;
}
interface OpenRowState {
@@ -24,7 +22,6 @@ interface OpenRowState {
class OpenRow extends React.Component<OpenRowProps, OpenRowState> {
- public _isMounted: boolean;
constructor(props: OpenRowProps) {
super(props);
this.state = {
@@ -33,20 +30,16 @@ class OpenRow extends React.Component {
};
}
- showFormatModal = (record: TableObj) => {
+ showFormatModal = (trial: Trial) => {
// get copy parameters
- const params = JSON.stringify(record.description.parameters, null, 4);
+ const params = JSON.stringify(trial.info.hyperParameters, null, 4);
// open modal with format string
- if (this._isMounted === true) {
- this.setState(() => ({ isShowFormatModal: true, formatStr: params }));
- }
+ this.setState({ isShowFormatModal: true, formatStr: params });
}
hideFormatModal = () => {
// close modal, destroy state format string data
- if (this._isMounted === true) {
- this.setState(() => ({ isShowFormatModal: false, formatStr: '' }));
- }
+ this.setState({ isShowFormatModal: false, formatStr: '' });
}
copyParams = () => {
@@ -62,68 +55,47 @@ class OpenRow extends React.Component {
this.hideFormatModal();
}
- componentDidMount() {
- this._isMounted = true;
- }
-
- componentWillUnmount() {
- this._isMounted = false;
- }
render() {
- const { trainingPlatform, record, logCollection, multiphase } = this.props;
const { isShowFormatModal, formatStr } = this.state;
- let isClick = false;
- let isHasParameters = true;
- if (record.description.parameters.error) {
- isHasParameters = false;
- }
- const openRowDataSource = record.description.parameters;
- const trialink: string = `${MANAGER_IP}/trial-jobs/${record.id}`;
- const logPathRow = record.description.logPath !== undefined
- ?
- record.description.logPath
- :
- 'This trial\'s log path are not available.';
+ const trialId = this.props.trialId;
+ const trial = TRIALS.getTrial(trialId);
+ const trialLink: string = `${MANAGER_IP}/trial-jobs/${trialId}`;
+ const logPathRow = trial.info.logPath || 'This trial\'s log path is not available.';
+ const multiProgress = trial.info.hyperParameters === undefined ? 0 : trial.info.hyperParameters.length;
return (
{
- multiphase
+ EXPERIMENT.multiPhase
?
Trials for multiphase experiment will return a set of parameters,
we are listing the latest parameter in webportal.
For the entire parameter set, please refer to the following "
- {trialink} ".
-
- Current Phase: {record.description.multiProgress}.
+ {trialLink} ".
+
+ Current Phase: {multiProgress}.
:
}
{
- isHasParameters
+ trial.info.hyperParameters !== undefined
?
- {
- isClick
- ?
- {JSON.stringify(openRowDataSource, null, 4)}
- :
- true} // default expandNode
- getItemString={() => ( )} // remove the {} items
- data={openRowDataSource}
- />
- }
+ true} // default expandNode
+ getItemString={() => ( )} // remove the {} items
+ data={trial.description.parameters}
+ />
Copy as json
@@ -138,15 +110,16 @@ class OpenRow extends React.Component {
{
- trainingPlatform !== 'local'
+ // FIXME: this should not be handled in web UI side
+ EXPERIMENT.trainingServicePlatform !== 'local'
?
:
-
+
}
@@ -170,4 +143,4 @@ class OpenRow extends React.Component {
}
}
-export default OpenRow;
\ No newline at end of file
+export default OpenRow;
diff --git a/src/webui/src/components/trial-detail/DefaultMetricPoint.tsx b/src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
index 966df304ef..0938da1cdc 100644
--- a/src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
+++ b/src/webui/src/components/trial-detail/DefaultMetricPoint.tsx
@@ -1,278 +1,41 @@
import * as React from 'react';
import { Switch } from 'antd';
import ReactEcharts from 'echarts-for-react';
-import { filterByStatus } from '../../static/function';
-import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface';
+import { EXPERIMENT, TRIALS } from '../../static/datamodel';
+import { Trial } from '../../static/model/trial';
+import { TooltipForAccuracy } from '../../static/interface';
require('echarts/lib/chart/scatter');
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
interface DefaultPointProps {
- showSource: Array;
- height: number;
- whichGraph: string;
- optimize: string;
+ trialIds: string[];
+ visible: boolean;
+ trialsUpdateBroadcast: number;
}
interface DefaultPointState {
- defaultSource: object;
- accNodata: string;
- succeedTrials: number;
- isViewBestCurve: boolean;
+ bestCurveEnabled: boolean;
}
class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
- public _isDefaultMounted = false;
-
constructor(props: DefaultPointProps) {
super(props);
- this.state = {
- defaultSource: {},
- accNodata: '',
- succeedTrials: 10000000,
- isViewBestCurve: false
- };
- }
-
- defaultMetric = (succeedSource: Array, isCurve: boolean) => {
- const { optimize } = this.props;
- const accSource: Array = [];
- const drawSource: Array = succeedSource.filter(filterByStatus);
- const lengthOfSource = drawSource.length;
- const tooltipDefault = lengthOfSource === 0 ? 'No data' : '';
- if (this._isDefaultMounted === true) {
- this.setState(() => ({
- succeedTrials: lengthOfSource,
- accNodata: tooltipDefault
- }));
- }
- if (lengthOfSource === 0) {
- const nullGraph = {
- grid: {
- left: '8%'
- },
- xAxis: {
- name: 'Trial',
- type: 'category',
- },
- yAxis: {
- name: 'Default metric',
- type: 'value',
- }
- };
- if (this._isDefaultMounted === true) {
- this.setState(() => ({
- defaultSource: nullGraph
- }));
- }
- } else {
- const resultList: Array[] = [];
- // lineListDefault: [[sequenceId, default metric], []]
- const lineListDefault: Array[] = [];
- Object.keys(drawSource).map(item => {
- const temp = drawSource[item];
- if (temp.acc !== undefined) {
- if (temp.acc.default !== undefined) {
- const searchSpace = temp.description.parameters;
- lineListDefault.push([temp.sequenceId, temp.acc.default]);
- accSource.push({
- acc: temp.acc.default,
- index: temp.sequenceId,
- searchSpace: searchSpace
- });
- }
- }
- });
- // deal with best metric line
- const bestCurve: Array[] = []; // best curve data source
- if (lineListDefault[0] !== undefined) {
- bestCurve.push([lineListDefault[0][0], lineListDefault[0][1], accSource[0].searchSpace]);
- }
- if (optimize === 'maximize') {
- for (let i = 1; i < lineListDefault.length; i++) {
- const val = lineListDefault[i][1];
- const latest = bestCurve[bestCurve.length - 1][1];
- if (val >= latest) {
- bestCurve.push([lineListDefault[i][0], val, accSource[i].searchSpace]);
- } else {
- bestCurve.push([lineListDefault[i][0], latest, accSource[i].searchSpace]);
- }
- }
- } else {
- for (let i = 1; i < lineListDefault.length; i++) {
- const val = lineListDefault[i][1];
- const latest = bestCurve[bestCurve.length - 1][1];
- if (val <= latest) {
- bestCurve.push([lineListDefault[i][0], val, accSource[i].searchSpace]);
- } else {
- bestCurve.push([lineListDefault[i][0], latest, accSource[i].searchSpace]);
- }
- }
- }
- Object.keys(accSource).map(item => {
- const items = accSource[item];
- let temp: Array;
- temp = [items.index, items.acc, items.searchSpace];
- resultList.push(temp);
- });
- // isViewBestCurve: false show default metric graph
- // isViewBestCurve: true show best curve
- if (isCurve === true) {
- if (this._isDefaultMounted === true) {
- this.setState(() => ({
- defaultSource: this.drawBestcurve(bestCurve, resultList)
- }));
- }
- } else {
- if (this._isDefaultMounted === true) {
- this.setState(() => ({
- defaultSource: this.drawDefaultMetric(resultList)
- }));
- }
- }
- }
- }
-
- drawBestcurve = (realDefault: Array[], resultList: Array[]) => {
- return {
- grid: {
- left: '8%'
- },
- tooltip: {
- trigger: 'item',
- enterable: true,
- position: function (point: Array, data: TooltipForAccuracy) {
- if (data.data[0] < realDefault.length / 2) {
- return [point[0], 80];
- } else {
- return [point[0] - 300, 80];
- }
- },
- formatter: function (data: TooltipForAccuracy) {
- const result = '';
- return result;
- }
- },
- xAxis: {
- name: 'Trial',
- type: 'category',
- },
- yAxis: {
- name: 'Default metric',
- type: 'value',
- scale: true
- },
- series: [
- {
- type: 'line',
- lineStyle: { color: '#FF6600' },
- data: realDefault
- },
- {
- symbolSize: 6,
- type: 'scatter',
- data: resultList
- }]
- };
- }
-
- drawDefaultMetric = (resultList: Array[]) => {
- return {
- grid: {
- left: '8%'
- },
- tooltip: {
- trigger: 'item',
- enterable: true,
- position: function (point: Array, data: TooltipForAccuracy) {
- if (data.data[0] < resultList.length / 2) {
- return [point[0], 80];
- } else {
- return [point[0] - 300, 80];
- }
- },
- formatter: function (data: TooltipForAccuracy) {
- const result = '';
- return result;
- }
- },
- xAxis: {
- name: 'Trial',
- type: 'category',
- },
- yAxis: {
- name: 'Default metric',
- type: 'value',
- scale: true
- },
- series: [{
- symbolSize: 6,
- type: 'scatter',
- data: resultList
- }]
- };
+ this.state = { bestCurveEnabled: false };
}
loadDefault = (checked: boolean) => {
- // checked: true show best metric curve
- const { showSource } = this.props;
- if (this._isDefaultMounted === true) {
- this.defaultMetric(showSource, checked);
- // ** deal with data and then update view layer
- this.setState(() => ({ isViewBestCurve: checked }));
- }
- }
-
- // update parent component state
- componentWillReceiveProps(nextProps: DefaultPointProps) {
-
- const { whichGraph, showSource } = nextProps;
- const { isViewBestCurve } = this.state;
- if (whichGraph === '1') {
- this.defaultMetric(showSource, isViewBestCurve);
- }
+ this.setState({ bestCurveEnabled: checked });
}
shouldComponentUpdate(nextProps: DefaultPointProps, nextState: DefaultPointState) {
- const { whichGraph } = nextProps;
- if (whichGraph === '1') {
- const { succeedTrials, isViewBestCurve } = nextState;
- const succTrial = this.state.succeedTrials;
- const isViewBestCurveBefore = this.state.isViewBestCurve;
- if (isViewBestCurveBefore !== isViewBestCurve) {
- return true;
- }
- if (succeedTrials !== succTrial) {
- return true;
- }
- }
- // only whichGraph !== '1', default metric can't update
- return false;
- }
-
- componentDidMount() {
- this._isDefaultMounted = true;
- }
-
- componentWillUnmount() {
- this._isDefaultMounted = false;
+ return nextProps.visible;
}
render() {
- const { height } = this.props;
- const { defaultSource, accNodata } = this.state;
+ const graph = this.generateGraph();
+ const accNodata = (graph === EmptyGraph ? 'No data' : '');
+
return (
@@ -282,10 +45,10 @@ class DefaultPoint extends React.Component
);
}
+
+ private generateGraph() {
+ const trials = TRIALS.getTrials(this.props.trialIds).filter(trial => trial.sortable);
+ if (trials.length === 0) {
+ return EmptyGraph;
+ }
+ const graph = generateGraphConfig(trials[trials.length - 1].sequenceId);
+ if (this.state.bestCurveEnabled) {
+ (graph as any).series = [ generateBestCurveSeries(trials), generateScatterSeries(trials) ];
+ } else {
+ (graph as any).series = [ generateScatterSeries(trials) ];
+ }
+ return graph;
+ }
+}
+
+const EmptyGraph = {
+ grid: {
+ left: '8%'
+ },
+ xAxis: {
+ name: 'Trial',
+ type: 'category',
+ },
+ yAxis: {
+ name: 'Default metric',
+ type: 'value',
+ }
+};
+
+function generateGraphConfig(maxSequenceId: number) {
+ return {
+ grid: {
+ left: '8%',
+ },
+ tooltip: {
+ trigger: 'item',
+ enterable: true,
+ position: (point: Array, data: TooltipForAccuracy) => (
+ [ (data.data[0] < maxSequenceId ? point[0] : (point[0] - 300)), 80 ]
+ ),
+ formatter: (data: TooltipForAccuracy) => (
+ ''
+ ),
+ },
+ xAxis: {
+ name: 'Trial',
+ type: 'category',
+ },
+ yAxis: {
+ name: 'Default metric',
+ type: 'value',
+ scale: true,
+ },
+ series: undefined,
+ };
+}
+
+function generateScatterSeries(trials: Trial[]) {
+ const data = trials.map(trial => [
+ trial.sequenceId,
+ trial.accuracy,
+ trial.description.parameters,
+ ]);
+ return {
+ symbolSize: 6,
+ type: 'scatter',
+ data,
+ };
+}
+
+function generateBestCurveSeries(trials: Trial[]) {
+ let best = trials[0];
+ const data = [[ best.sequenceId, best.accuracy, best.info.hyperParameters ]];
+
+ for (let i = 1; i < trials.length; i++) {
+ const trial = trials[i];
+ const delta = trial.accuracy! - best.accuracy!;
+ const better = (EXPERIMENT.optimizeMode === 'minimize') ? (delta < 0) : (delta > 0);
+ if (better) {
+ data.push([ trial.sequenceId, trial.accuracy, trial.info.hyperParameters ]);
+ best = trial;
+ } else {
+ data.push([ trial.sequenceId, best.accuracy, trial.info.hyperParameters ]);
+ }
+ }
+
+ return {
+ type: 'line',
+ lineStyle: { color: '#FF6600' },
+ data,
+ };
}
-export default DefaultPoint;
\ No newline at end of file
+export default DefaultPoint;
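
Reviewer note: `generateBestCurveSeries` above computes a running "best so far" value over the trials, in the direction given by the experiment's optimize mode. The standalone sketch below isolates that idea. `TrialPoint`, `OptimizeMode`, and `bestSoFarCurve` are hypothetical names used only for illustration; they are not part of the web UI code, and the real series also carries each trial's hyperparameters as a third element.

```typescript
// Hypothetical minimal trial shape; the real Trial class carries more fields.
interface TrialPoint {
    sequenceId: number;
    accuracy: number;
}

type OptimizeMode = 'maximize' | 'minimize';

// Returns one [sequenceId, bestSoFar] pair per trial, in input order.
function bestSoFarCurve(trials: TrialPoint[], mode: OptimizeMode): Array<[number, number]> {
    if (trials.length === 0) {
        return [];
    }
    let best = trials[0].accuracy;
    const curve: Array<[number, number]> = [];
    for (const trial of trials) {
        // A trial replaces the current best only if it is strictly better.
        const better = mode === 'minimize' ? trial.accuracy < best : trial.accuracy > best;
        if (better) {
            best = trial.accuracy;
        }
        curve.push([trial.sequenceId, best]);
    }
    return curve;
}

const sample: TrialPoint[] = [
    { sequenceId: 0, accuracy: 0.5 },
    { sequenceId: 1, accuracy: 0.4 },
    { sequenceId: 2, accuracy: 0.7 },
];
const maximized = bestSoFarCurve(sample, 'maximize');
const minimized = bestSoFarCurve(sample, 'minimize');
```

The sketch assumes the trials are already ordered by sequence ID and that every trial has a numeric accuracy, which is what the `sortable` filter in `generateGraph` appears to guarantee.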
diff --git a/src/webui/src/components/trial-detail/Duration.tsx b/src/webui/src/components/trial-detail/Duration.tsx
index d47b5107ce..c8add154b4 100644
--- a/src/webui/src/components/trial-detail/Duration.tsx
+++ b/src/webui/src/components/trial-detail/Duration.tsx
@@ -22,8 +22,6 @@ interface DurationState {
class Duration extends React.Component {
- public _isMounted = false;
-
constructor(props: DurationProps) {
super(props);
@@ -142,15 +140,12 @@ class Duration extends React.Component {
trialId: trialId,
trialTime: trialTime
});
- if (this._isMounted) {
- this.setState({
- durationSource: this.getOption(trialRun[0])
- });
- }
+ this.setState({
+ durationSource: this.getOption(trialRun[0])
+ });
}
componentDidMount() {
- this._isMounted = true;
const { source } = this.props;
this.drawDurationGraph(source);
}
@@ -187,10 +182,6 @@ class Duration extends React.Component {
return false;
}
- componentWillUnmount() {
- this._isMounted = false;
- }
-
render() {
const { durationSource } = this.state;
return (
@@ -206,4 +197,4 @@ class Duration extends React.Component {
}
}
-export default Duration;
\ No newline at end of file
+export default Duration;
diff --git a/src/webui/src/components/trial-detail/Intermediate.tsx b/src/webui/src/components/trial-detail/Intermediate.tsx
index 9a9dfa1e6d..3c24b8f497 100644
--- a/src/webui/src/components/trial-detail/Intermediate.tsx
+++ b/src/webui/src/components/trial-detail/Intermediate.tsx
@@ -24,7 +24,6 @@ interface IntermediateProps {
class Intermediate extends React.Component {
static intervalMediate = 1;
- public _isMounted = false;
public pointInput: HTMLInputElement | null;
public minValInput: HTMLInputElement | null;
public maxValInput: HTMLInputElement | null;
@@ -45,12 +44,10 @@ class Intermediate extends React.Component
drawIntermediate = (source: Array) => {
if (source.length > 0) {
- if (this._isMounted) {
- this.setState(() => ({
- length: source.length,
- detailSource: source
- }));
- }
+ this.setState({
+ length: source.length,
+ detailSource: source
+ });
const trialIntermediate: Array = [];
Object.keys(source).map(item => {
const temp = source[item];
@@ -118,11 +115,9 @@ class Intermediate extends React.Component
},
series: trialIntermediate
};
- if (this._isMounted) {
- this.setState(() => ({
- interSource: option
- }));
- }
+ this.setState({
+ interSource: option
+ });
} else {
const nullData = {
grid: {
@@ -139,71 +134,60 @@ class Intermediate extends React.Component
name: 'Metric'
}
};
- if (this._isMounted) {
- this.setState(() => ({ interSource: nullData }));
- }
+ this.setState({ interSource: nullData });
}
}
// confirm btn function [filter data]
filterLines = () => {
- if (this._isMounted) {
- const filterSource: Array = [];
- this.setState({ isLoadconfirmBtn: true }, () => {
- const { source } = this.props;
- // get input value
- const pointVal = this.pointInput !== null ? this.pointInput.value : '';
- const minVal = this.minValInput !== null ? this.minValInput.value : '';
- const maxVal = this.maxValInput !== null ? this.maxValInput.value : '';
- // user not input message
- if (pointVal === '' || minVal === '') {
- alert('Please input filter message');
+ const filterSource: Array = [];
+ this.setState({ isLoadconfirmBtn: true }, () => {
+ const { source } = this.props;
+ // get input value
+ const pointVal = this.pointInput !== null ? this.pointInput.value : '';
+ const minVal = this.minValInput !== null ? this.minValInput.value : '';
+ const maxVal = this.maxValInput !== null ? this.maxValInput.value : '';
+ // the user left a required field (point or min value) empty
+ if (pointVal === '' || minVal === '') {
+ alert('Please input filter message');
+ } else {
+ // parse the inputs; the max value is optional
+ const position = JSON.parse(pointVal);
+ const min = JSON.parse(minVal);
+ if (maxVal === '') {
+ Object.keys(source).map(item => {
+ const temp = source[item];
+ const val = temp.description.intermediate[position - 1];
+ if (val >= min) {
+ filterSource.push(temp);
+ }
+ });
} else {
- // user not input max value
- const position = JSON.parse(pointVal);
- const min = JSON.parse(minVal);
- if (maxVal === '') {
- Object.keys(source).map(item => {
- const temp = source[item];
- const val = temp.description.intermediate[position - 1];
- if (val >= min) {
- filterSource.push(temp);
- }
- });
- } else {
- const max = JSON.parse(maxVal);
- Object.keys(source).map(item => {
- const temp = source[item];
- const val = temp.description.intermediate[position - 1];
- if (val >= min && val <= max) {
- filterSource.push(temp);
- }
- });
- }
- if (this._isMounted) {
- this.setState({ filterSource: filterSource });
- }
- this.drawIntermediate(filterSource);
- }
- const counts = this.state.clickCounts + 1;
- if (this._isMounted) {
- this.setState({ isLoadconfirmBtn: false, clickCounts: counts });
+ const max = JSON.parse(maxVal);
+ Object.keys(source).map(item => {
+ const temp = source[item];
+ const val = temp.description.intermediate[position - 1];
+ if (val >= min && val <= max) {
+ filterSource.push(temp);
+ }
+ });
}
- });
- }
+ this.setState({ filterSource: filterSource });
+ this.drawIntermediate(filterSource);
+ }
+ const counts = this.state.clickCounts + 1;
+ this.setState({ isLoadconfirmBtn: false, clickCounts: counts });
+ });
}
switchTurn = (checked: boolean) => {
- if (this._isMounted) {
- this.setState({ isFilter: checked });
- }
+ this.setState({ isFilter: checked });
if (checked === false) {
this.drawIntermediate(this.props.source);
}
}
componentDidMount() {
- this._isMounted = true;
const { source } = this.props;
this.drawIntermediate(source);
}
@@ -272,10 +256,6 @@ class Intermediate extends React.Component
return false;
}
- componentWillUnmount() {
- this._isMounted = false;
- }
-
render() {
const { interSource, isLoadconfirmBtn, isFilter } = this.state;
return (
@@ -292,7 +272,7 @@ class Intermediate extends React.Component
isFilter
?
- # Intermediate
+ # Intermediate result
this.pointInput = input}
@@ -327,7 +307,7 @@ class Intermediate extends React.Component
style={{ width: '100%', height: 418, margin: '0 auto' }}
notMerge={true} // update now
/>
- # Intermediate
+ # Intermediate result
);
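
Reviewer note: the `filterLines` handler in this file keeps the trials whose intermediate result at a chosen 1-based position falls within a minimum and an optional maximum. A minimal standalone sketch of that filter follows. `TrialSeries` and `filterByIntermediate` are hypothetical names for illustration, and the explicit `undefined` check is an assumption about how trials that have not reached the position should be handled (the original comparison also excludes them, because `undefined >= min` is false).

```typescript
// Hypothetical minimal trial shape: only the list of intermediate results matters here.
interface TrialSeries {
    id: string;
    intermediate: number[];
}

// Keeps trials whose intermediate result at the 1-based `position` lies in [min, max].
// `max` is optional; trials that have not reached `position` yet are dropped.
function filterByIntermediate(
    trials: TrialSeries[],
    position: number,
    min: number,
    max?: number
): TrialSeries[] {
    return trials.filter(trial => {
        const value = trial.intermediate[position - 1];
        if (value === undefined) {
            return false;
        }
        return value >= min && (max === undefined || value <= max);
    });
}

const series: TrialSeries[] = [
    { id: 'a', intermediate: [0.1, 0.5, 0.9] },
    { id: 'b', intermediate: [0.2, 0.3] },
    { id: 'c', intermediate: [0.4] },
];
const atLeast = filterByIntermediate(series, 2, 0.4).map(trial => trial.id);
const inRange = filterByIntermediate(series, 1, 0.15, 0.3).map(trial => trial.id);
```

Folding the two branches of the original (with and without a max value) into one predicate is a design choice of this sketch, not a change proposed for the patch.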
diff --git a/src/webui/src/components/trial-detail/Para.tsx b/src/webui/src/components/trial-detail/Para.tsx
index 522b691821..68b3814021 100644
--- a/src/webui/src/components/trial-detail/Para.tsx
+++ b/src/webui/src/components/trial-detail/Para.tsx
@@ -40,8 +40,6 @@ message.config({
class Para extends React.Component {
- public _isMounted = false;
-
private chartMulineStyle = {
width: '100%',
height: 392,
@@ -121,101 +119,157 @@ class Para extends React.Component {
this.swapGraph(paraData, swapAxisArr);
}
this.getOption(paraData, lengthofTrials);
- if (this._isMounted === true) {
- this.setState(() => ({ paraBack: paraData }));
- }
+ this.setState({ paraBack: paraData });
}
hyperParaPic = (source: Array, searchSpace: string) => {
// filter succeed trials [{}, {}, {}]
- const dataSource: Array = source.filter(filterByStatus);
+ const dataSource = source.filter(filterByStatus);
const lenOfDataSource: number = dataSource.length;
const accPara: Array = [];
// specific value array
const eachTrialParams: Array = [];
// experiment interface search space obj
const searchRange = searchSpace !== undefined ? JSON.parse(searchSpace) : '';
+ // detect a nested search space (a choice whose values are objects)
+ let isNested: boolean = false;
+ Object.keys(searchRange).map(item => {
+ if (searchRange[item]._value && typeof searchRange[item]._value[0] === 'object') {
+ isNested = true;
+ return;
+ }
+ });
const dimName = Object.keys(searchRange);
- if (this._isMounted === true) {
- this.setState(() => ({ dimName: dimName }));
- }
+ this.setState({ dimName: dimName });
const parallelAxis: Array = [];
// search space range and specific value [only number]
let i = 0;
- for (i; i < dimName.length; i++) {
- const searchKey = searchRange[dimName[i]];
- switch (searchKey._type) {
- case 'uniform':
- case 'quniform':
- parallelAxis.push({
- dim: i,
- name: dimName[i],
- max: searchKey._value[1],
- min: searchKey._value[0]
- });
- break;
-
- case 'randint':
- parallelAxis.push({
- dim: i,
- name: dimName[i],
- min: searchKey._value[0],
- max: searchKey._value[1],
- });
- break;
-
- case 'choice':
- const data: Array = [];
- for (let j = 0; j < searchKey._value.length; j++) {
- data.push(searchKey._value[j].toString());
- }
- parallelAxis.push({
- dim: i,
- name: dimName[i],
- type: 'category',
- data: data,
- boundaryGap: true,
- axisLine: {
- lineStyle: {
- type: 'dotted', // axis type,solid,dashed,dotted
- width: 1
- }
- },
- axisTick: {
- show: true,
- interval: 0,
- alignWithLabel: true,
- },
- axisLabel: {
- show: true,
- interval: 0,
- // rotate: 30
- },
- });
- break;
- // support log distribute
- case 'loguniform':
- if (lenOfDataSource > 1) {
+ if (isNested === false) {
+ for (i; i < dimName.length; i++) {
+ const searchKey = searchRange[dimName[i]];
+ switch (searchKey._type) {
+ case 'uniform':
+ case 'quniform':
parallelAxis.push({
dim: i,
name: dimName[i],
- type: 'log',
+ max: searchKey._value[1],
+ min: searchKey._value[0]
});
- } else {
+ break;
+ case 'randint':
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i],
+ min: searchKey._value[0],
+ max: searchKey._value[1],
+ });
+ break;
+ case 'choice':
+ const data: Array = [];
+ for (let j = 0; j < searchKey._value.length; j++) {
+ data.push(searchKey._value[j].toString());
+ }
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i],
+ type: 'category',
+ data: data,
+ boundaryGap: true,
+ axisLine: {
+ lineStyle: {
+ type: 'dotted', // line type: solid, dashed, or dotted
+ width: 1
+ }
+ },
+ axisTick: {
+ show: true,
+ interval: 0,
+ alignWithLabel: true,
+ },
+ axisLabel: {
+ show: true,
+ interval: 0,
+ // rotate: 30
+ },
+ });
+ break;
+ // support log distribute
+ case 'loguniform':
+ if (lenOfDataSource > 1) {
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i],
+ type: 'log',
+ });
+ } else {
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i]
+ });
+ }
+ break;
+ default:
parallelAxis.push({
dim: i,
name: dimName[i]
});
- }
- break;
-
- default:
- parallelAxis.push({
- dim: i,
- name: dimName[i]
- });
-
+ }
+ }
+ } else {
+ for (i; i < dimName.length; i++) {
+ const searchKey = searchRange[dimName[i]];
+ switch (searchKey._type) {
+ case 'choice':
+ const data: Array = [];
+ let j = 0;
+ for (j; j < searchKey._value.length; j++) {
+ const item = searchKey._value[j];
+ Object.keys(item).map(key => {
+ if (key !== '_name' && key !== '_type') {
+ Object.keys(item[key]).map(index => {
+ if (index !== '_type') {
+ const realChoice = item[key][index];
+ Object.keys(realChoice).map(m => {
+ data.push(`${item._name}_${realChoice[m]}`);
+ });
+ }
+ });
+ }
+ });
+ }
+ data.push('null');
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i],
+ type: 'category',
+ data: data,
+ boundaryGap: true,
+ axisLine: {
+ lineStyle: {
+ type: 'dotted', // line type: solid, dashed, or dotted
+ width: 1
+ }
+ },
+ axisTick: {
+ show: true,
+ interval: 0,
+ alignWithLabel: true,
+ },
+ axisLabel: {
+ show: true,
+ interval: 0,
+ // rotate: 30
+ },
+ });
+ break;
+ default:
+ parallelAxis.push({
+ dim: i,
+ name: dimName[i]
+ });
+ }
}
}
parallelAxis.push({
@@ -263,34 +317,47 @@ class Para extends React.Component {
color: ['#CA0000', '#FFC400', '#90EE90']
}
};
- if (this._isMounted === true) {
- this.setState({
- paraNodata: 'No data',
- option: optionOfNull,
- sutrialCount: 0,
- succeedRenderCount: 0
- });
- }
+ this.setState({
+ paraNodata: 'No data',
+ option: optionOfNull,
+ sutrialCount: 0,
+ succeedRenderCount: 0
+ });
} else {
Object.keys(dataSource).map(item => {
- const temp = dataSource[item];
- eachTrialParams.push(temp.description.parameters);
+ const trial = dataSource[item];
+ eachTrialParams.push(trial.description.parameters || '');
// may be a succeed trial hasn't final result
// all detail page may be break down if havn't if
- if (temp.acc !== undefined) {
- if (temp.acc.default !== undefined) {
- accPara.push(temp.acc.default);
+ if (trial.acc !== undefined) {
+ if (trial.acc.default !== undefined) {
+ accPara.push(JSON.parse(trial.acc.default));
}
}
});
- if (this._isMounted) {
- // if not return final result
- const maxVal = accPara.length === 0 ? 1 : Math.max(...accPara);
- const minVal = accPara.length === 0 ? 1 : Math.min(...accPara);
- this.setState({ max: maxVal, min: minVal }, () => {
- this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams, lenOfDataSource);
+ // nested search space: flatten each nested choice into a single string label
+ if (isNested !== false) {
+ eachTrialParams.forEach(element => {
+ Object.keys(element).forEach(key => {
+ let item = element[key];
+ if (typeof item === 'object') {
+ Object.keys(item).forEach(index => {
+ if (index !== '_name') {
+ element[key] = `${item._name}_${item[index]}`;
+ } else {
+ element[key] = 'null';
+ }
+ });
+ }
+ });
});
}
+ // fall back to 1 when no trial has reported a final result yet
+ const maxVal = accPara.length === 0 ? 1 : Math.max(...accPara);
+ const minVal = accPara.length === 0 ? 1 : Math.min(...accPara);
+ this.setState({ max: maxVal, min: minVal }, () => {
+ this.getParallelAxis(dimName, parallelAxis, accPara, eachTrialParams, lenOfDataSource);
+ });
}
}
@@ -298,11 +365,9 @@ class Para extends React.Component {
percentNum = (value: string) => {
let vals = parseFloat(value);
- if (this._isMounted) {
- this.setState({ percent: vals }, () => {
- this.reInit();
- });
- }
+ this.setState({ percent: vals }, () => {
+ this.reInit();
+ });
}
// deal with response data into pic data
@@ -367,22 +432,17 @@ class Para extends React.Component {
}
};
// please wait the data
- if (this._isMounted) {
- this.setState(() => ({
- option: optionown,
- paraNodata: '',
- succeedRenderCount: lengthofTrials,
- sutrialCount: paralleData.length
- }));
- }
+ this.setState({
+ option: optionown,
+ paraNodata: '',
+ succeedRenderCount: lengthofTrials,
+ sutrialCount: paralleData.length
+ });
}
// get swap parallel axis
getSwapArr = (value: Array) => {
-
- if (this._isMounted) {
- this.setState(() => ({ swapAxisArr: value }));
- }
+ this.setState({ swapAxisArr: value });
}
reInit = () => {
@@ -393,9 +453,7 @@ class Para extends React.Component {
swapReInit = () => {
const { clickCounts, succeedRenderCount } = this.state;
const val = clickCounts + 1;
- if (this._isMounted) {
- this.setState({ isLoadConfirm: true, clickCounts: val, });
- }
+ this.setState({ isLoadConfirm: true, clickCounts: val, });
const { paraBack, swapAxisArr } = this.state;
const paralDim = paraBack.parallelAxis;
const paraData = paraBack.data;
@@ -445,11 +503,9 @@ class Para extends React.Component {
});
this.getOption(paraBack, succeedRenderCount);
// please wait the data
- if (this._isMounted) {
- this.setState(() => ({
- isLoadConfirm: false
- }));
- }
+ this.setState({
+ isLoadConfirm: false
+ });
}
sortDimY = (a: Dimobj, b: Dimobj) => {
@@ -507,7 +563,6 @@ class Para extends React.Component {
}
componentDidMount() {
- this._isMounted = true;
this.reInit();
}
@@ -545,10 +600,6 @@ class Para extends React.Component {
return false;
}
- componentWillUnmount() {
- this._isMounted = false;
- }
-
render() {
const { option, paraNodata, dimName, isLoadConfirm } = this.state;
return (
@@ -609,4 +660,4 @@ class Para extends React.Component {
}
}
-export default Para;
\ No newline at end of file
+export default Para;
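
Reviewer note: for nested search spaces, `hyperParaPic` rewrites each nested choice in a trial's parameters into a single string of the form `<_name>_<value>`, or `'null'` when the choice has no inner value, so that it can be placed on a categorical parallel axis. The sketch below shows that flattening on its own. `NestedChoice`, `TrialParams`, and `flattenParams` are hypothetical names, and taking the last non-`_name` key is an assumption of this sketch; the loop in the patch assigns once per key, so its result depends on key order.

```typescript
// Hypothetical shape of one nested choice value, e.g. { _name: 'conv', kernel: 3 }.
type NestedChoice = { _name: string; [key: string]: string | number };
type TrialParams = { [name: string]: string | number | NestedChoice };

// Returns a copy of `params` in which every nested choice is replaced by a string label.
function flattenParams(params: TrialParams): { [name: string]: string | number } {
    const flat: { [name: string]: string | number } = {};
    for (const name of Object.keys(params)) {
        const value = params[name];
        if (typeof value !== 'object') {
            // Plain numbers and strings are kept unchanged.
            flat[name] = value;
            continue;
        }
        const inner = Object.keys(value).filter(key => key !== '_name');
        flat[name] = inner.length === 0
            ? 'null'
            : `${value._name}_${value[inner[inner.length - 1]]}`;
    }
    return flat;
}

const flattened = flattenParams({
    lr: 0.1,
    layer: { _name: 'conv', kernel: 3 },
    skip: { _name: 'empty' },
});
```

Unlike the patch, the sketch returns a new object instead of mutating the trial's parameters in place.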
diff --git a/src/webui/src/components/trial-detail/TableList.tsx b/src/webui/src/components/trial-detail/TableList.tsx
index a1af9be602..2048e91fc0 100644
--- a/src/webui/src/components/trial-detail/TableList.tsx
+++ b/src/webui/src/components/trial-detail/TableList.tsx
@@ -2,14 +2,15 @@ import * as React from 'react';
import axios from 'axios';
import ReactEcharts from 'echarts-for-react';
import { Row, Table, Button, Popconfirm, Modal, Checkbox, Select, Icon } from 'antd';
+import { ColumnProps } from 'antd/lib/table';
const Option = Select.Option;
const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN_INDEX, COLUMNPro } from '../../static/const';
-import { convertDuration, intermediateGraphOption, killJob, filterByStatus } from '../../static/function';
-import { TableObj, TrialJob } from '../../static/interface';
+import { convertDuration, formatTimestamp, intermediateGraphOption, killJob } from '../../static/function';
+import { TRIALS } from '../../static/datamodel';
+import { TableRecord } from '../../static/interface';
import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare';
-import IntermediateVal from '../public-child/IntermediateVal'; // table default metric column
import '../../static/style/search.scss';
require('../../static/style/tableStatus.css');
require('../../static/style/logPath.scss');
@@ -26,14 +27,11 @@ echarts.registerTheme('my_theme', {
});
interface TableListProps {
- entries: number;
- tableSource: Array;
- updateList: Function;
- platform: string;
- logCollection: boolean;
- isMultiPhase: boolean;
+ pageSize: number;
+ tableSource: Array;
columnList: Array; // user select columnKeys
changeColumn: (val: Array) => void;
+ trialsUpdateBroadcast: number;
}
interface TableListState {
@@ -41,7 +39,7 @@ interface TableListState {
modalVisible: boolean;
isObjFinal: boolean;
isShowColumn: boolean;
- selectRows: Array;
+ selectRows: Array;
isShowCompareModal: boolean;
selectedRowKeys: string[] | number[];
intermediateData: Array; // a trial's intermediate results (include dict)
@@ -56,10 +54,9 @@ interface ColumnIndex {
class TableList extends React.Component {
- public _isMounted = false;
public intervalTrialLog = 10;
public _trialId: string;
- public tables: Table | null;
+ public tables: Table | null;
constructor(props: TableListProps) {
super(props);
@@ -78,46 +75,35 @@ class TableList extends React.Component {
};
}
- showIntermediateModal = (id: string) => {
-
- axios(`${MANAGER_IP}/metric-data/${id}`, {
- method: 'GET'
- })
- .then(res => {
- if (res.status === 200) {
- const intermediateArr: number[] = [];
- // support intermediate result is dict because the last intermediate result is
- // final result in a succeed trial, it may be a dict.
- // get intermediate result dict keys array
- let otherkeys: Array = ['default'];
- if (res.data.length !== 0) {
- otherkeys = Object.keys(JSON.parse(res.data[0].data));
- }
- // intermediateArr just store default val
- Object.keys(res.data).map(item => {
- const temp = JSON.parse(res.data[item].data);
- if (typeof temp === 'object') {
- intermediateArr.push(temp.default);
- } else {
- intermediateArr.push(temp);
- }
- });
- const intermediate = intermediateGraphOption(intermediateArr, id);
- if (this._isMounted) {
- this.setState(() => ({
- intermediateData: res.data, // store origin intermediate data for a trial
- intermediateOption: intermediate,
- intermediateOtherKeys: otherkeys,
- intermediateId: id
- }));
- }
+ showIntermediateModal = async (id: string) => {
+ const res = await axios.get(`${MANAGER_IP}/metric-data/${id}`);
+ if (res.status === 200) {
+ const intermediateArr: number[] = [];
+ // An intermediate result may be a dict: in a succeeded trial the last
+ // intermediate result is the final result, which may be a dict.
+ // Collect the keys of the intermediate result dict.
+ let otherkeys: Array = ['default'];
+ if (res.data.length !== 0) {
+ otherkeys = Object.keys(JSON.parse(res.data[0].data));
+ }
+ // intermediateArr stores only the default value
+ Object.keys(res.data).map(item => {
+ const temp = JSON.parse(res.data[item].data);
+ if (typeof temp === 'object') {
+ intermediateArr.push(temp.default);
+ } else {
+ intermediateArr.push(temp);
}
});
- if (this._isMounted) {
+ const intermediate = intermediateGraphOption(intermediateArr, id);
this.setState({
- modalVisible: true
+ intermediateData: res.data, // store origin intermediate data for a trial
+ intermediateOption: intermediate,
+ intermediateOtherKeys: otherkeys,
+ intermediateId: id
});
}
+ this.setState({ modalVisible: true });
}
// intermediate button click -> intermediate graph for each trial
@@ -147,43 +133,36 @@ class TableList extends React.Component {
}
const intermediate = intermediateGraphOption(intermediateArr, intermediateId);
// re-render
- if (this._isMounted) {
- this.setState(() => ({
- intermediateOption: intermediate
- }));
- }
+ this.setState({
+ intermediateOption: intermediate
+ });
}
hideIntermediateModal = () => {
- if (this._isMounted) {
- this.setState({
- modalVisible: false
- });
- }
+ this.setState({
+ modalVisible: false
+ });
}
hideShowColumnModal = () => {
- if (this._isMounted) {
- this.setState({
- isShowColumn: false
- });
- }
+ this.setState({
+ isShowColumn: false
+ });
}
// click add column btn, just show the modal of addcolumn
addColumn = () => {
// show user select check button
- if (this._isMounted) {
- this.setState({
- isShowColumn: true
- });
- }
+ this.setState({
+ isShowColumn: true
+ });
}
// checkbox for coloumn
selectedColumn = (checkedValues: Array) => {
- // 7: because have seven common column, "Intermediate count" is not shown by default
- let count = 7;
+ // 9: there are nine common columns;
+ // [Intermediate count, Start Time, End Time] are hidden by default
+ let count = 9;
const want: Array = [];
const finalKeys: Array = [];
const wantResult: Array = [];
@@ -191,11 +170,13 @@ class TableList extends React.Component {
switch (checkedValues[m]) {
case 'Trial No.':
case 'ID':
+ case 'Start Time':
+ case 'End Time':
case 'Duration':
case 'Status':
case 'Operation':
case 'Default':
- case 'Intermediate count':
+ case 'Intermediate result':
break;
default:
finalKeys.push(checkedValues[m]);
@@ -227,27 +208,17 @@ class TableList extends React.Component {
wantResult.push(want[i].name);
});
- if (this._isMounted) {
- this.props.changeColumn(wantResult);
- }
+ this.props.changeColumn(wantResult);
}
- openRow = (record: TableObj) => {
- const { platform, logCollection, isMultiPhase } = this.props;
+ openRow = (record: TableRecord) => {
return (
-
+
);
}
- fillSelectedRowsTostate = (selected: number[] | string[], selectedRows: Array) => {
- if (this._isMounted === true) {
- this.setState(() => ({ selectRows: selectedRows, selectedRowKeys: selected }));
- }
+ fillSelectedRowsTostate = (selected: number[] | string[], selectedRows: Array) => {
+ this.setState({ selectRows: selectedRows, selectedRowKeys: selected });
}
// open Compare-modal
compareBtn = () => {
@@ -256,178 +227,87 @@ class TableList extends React.Component {
if (selectRows.length === 0) {
alert('Please select datas you want to compare!');
} else {
- if (this._isMounted === true) {
- this.setState({ isShowCompareModal: true });
- }
+ this.setState({ isShowCompareModal: true });
}
}
// close Compare-modal
hideCompareModal = () => {
// close modal. clear select rows data, clear selected track
- if (this._isMounted) {
- this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
- }
- }
-
- componentDidMount() {
- this._isMounted = true;
- }
-
- componentWillUnmount() {
- this._isMounted = false;
+ this.setState({ isShowCompareModal: false, selectedRowKeys: [], selectRows: [] });
}
render() {
-
- const { entries, tableSource, updateList, columnList } = this.props;
+ const { pageSize, columnList } = this.props;
+ const tableSource: Array = JSON.parse(JSON.stringify(this.props.tableSource));
const { intermediateOption, modalVisible, isShowColumn,
selectRows, isShowCompareModal, selectedRowKeys, intermediateOtherKeys } = this.state;
const rowSelection = {
selectedRowKeys: selectedRowKeys,
- onChange: (selected: string[] | number[], selectedRows: Array) => {
+ onChange: (selected: string[] | number[], selectedRows: Array) => {
this.fillSelectedRowsTostate(selected, selectedRows);
}
};
let showTitle = COLUMNPro;
- let bgColor = '';
- const trialJob: Array = [];
const showColumn: Array = [];
+
+ // parameter as table column
+ const trialMess = TRIALS.getTrial(tableSource[0].id);
+ const trial = trialMess.description.parameters;
+ const parameterColumn: Array = Object.keys(trial);
+ const parameterStr: Array = [];
+ parameterColumn.forEach(value => {
+ parameterStr.push(`${value} (search space)`);
+ });
+ showTitle = COLUMNPro.concat(parameterStr);
+
// only succeed trials have final keys
- if (tableSource.filter(filterByStatus).length >= 1) {
- const temp = tableSource.filter(filterByStatus)[0].acc;
+ if (tableSource.filter(record => record.status === 'SUCCEEDED').length >= 1) {
+ const temp = tableSource.filter(record => record.status === 'SUCCEEDED')[0].accuracy;
if (temp !== undefined && typeof temp === 'object') {
- if (this._isMounted) {
- // concat default column and finalkeys
- const item = Object.keys(temp);
- // item: ['default', 'other-keys', 'maybe loss']
- if (item.length > 1) {
- const want: Array = [];
- item.forEach(value => {
- if (value !== 'default') {
- want.push(value);
- }
- });
- showTitle = COLUMNPro.concat(want);
- }
+ // concat default column and finalkeys
+ const item = Object.keys(temp);
+ // item: ['default', 'other-keys', 'maybe loss']
+ if (item.length > 1) {
+ const want: Array = [];
+ item.forEach(value => {
+ if (value !== 'default') {
+ want.push(value);
+ }
+ });
+ showTitle = COLUMNPro.concat(want);
}
}
}
- trialJobStatus.map(item => {
- trialJob.push({
- text: item,
- value: item
- });
- });
- Object.keys(columnList).map(key => {
- const item = columnList[key];
+ for (const item of columnList) {
+ const paraColumn = item.match(/ \(search space\)$/);
+ let cc;
+ if (paraColumn !== null) {
+ cc = paraColumn.input;
+ }
switch (item) {
case 'Trial No.':
- showColumn.push({
- title: 'Trial No.',
- dataIndex: 'sequenceId',
- key: 'sequenceId',
- width: 120,
- className: 'tableHead',
- sorter: (a: TableObj, b: TableObj) => (a.sequenceId as number) - (b.sequenceId as number)
- });
+ showColumn.push(SequenceIdColumnConfig);
break;
case 'ID':
- showColumn.push({
- title: 'ID',
- dataIndex: 'id',
- key: 'id',
- width: 60,
- className: 'tableHead leftTitle',
- // the sort of string
- sorter: (a: TableObj, b: TableObj): number => a.id.localeCompare(b.id),
- render: (text: string, record: TableObj) => {
- return (
- {record.id}
- );
- }
- });
+ showColumn.push(IdColumnConfig);
+ break;
+ case 'Start Time':
+ showColumn.push(StartTimeColumnConfig);
+ break;
+ case 'End Time':
+ showColumn.push(EndTimeColumnConfig);
break;
case 'Duration':
- showColumn.push({
- title: 'Duration',
- dataIndex: 'duration',
- key: 'duration',
- width: 100,
- // the sort of number
- sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
- render: (text: string, record: TableObj) => {
- let duration;
- if (record.duration !== undefined) {
- // duration is nagative number(-1) & 0-1
- if (record.duration > 0 && record.duration < 1 || record.duration < 0) {
- duration = `${record.duration}s`;
- } else {
- duration = convertDuration(record.duration);
- }
- } else {
- duration = 0;
- }
- return (
-
- );
- },
- });
+ showColumn.push(DurationColumnConfig);
break;
case 'Status':
- showColumn.push({
- title: 'Status',
- dataIndex: 'status',
- key: 'status',
- width: 150,
- className: 'tableStatus',
- render: (text: string, record: TableObj) => {
- bgColor = record.status;
- return (
- {record.status}
- );
- },
- filters: trialJob,
- onFilter: (value: string, record: TableObj) => {
- return record.status.indexOf(value) === 0;
- },
- // onFilter: (value: string, record: TableObj) => record.status.indexOf(value) === 0,
- sorter: (a: TableObj, b: TableObj): number => a.status.localeCompare(b.status)
- });
+ showColumn.push(StatusColumnConfig);
break;
- case 'Intermediate count':
- showColumn.push({
- title: 'Intermediate count',
- dataIndex: 'progress',
- key: 'progress',
- width: 86,
- render: (text: string, record: TableObj) => {
- return (
- {`#${record.description.intermediate.length}`}
- );
- },
- });
+ case 'Intermediate result':
+ showColumn.push(IntermediateCountColumnConfig);
break;
case 'Default':
- showColumn.push({
- title: 'Default metric',
- className: 'leftTitle',
- dataIndex: 'acc',
- key: 'acc',
- width: 120,
- sorter: (a: TableObj, b: TableObj) => {
- const oneArr = a.description.intermediate;
- const otherArr = b.description.intermediate;
- const one = (oneArr[oneArr.length - 1] !== undefined) ? oneArr[oneArr.length - 1] : 0;
- const other = (otherArr[otherArr.length - 1] !== undefined)
- ? otherArr[otherArr.length - 1] : 0;
- return one - other;
- },
- render: (text: string, record: TableObj) => {
- return (
-
- );
- }
- });
+ showColumn.push(AccuracyColumnConfig);
break;
case 'Operation':
showColumn.push({
@@ -435,7 +315,7 @@ class TableList extends React.Component {
dataIndex: 'operation',
key: 'operation',
width: 120,
- render: (text: string, record: TableObj) => {
+ render: (text: string, record: TableRecord) => {
let trialStatus = record.status;
const flag: boolean = (trialStatus === 'RUNNING') ? false : true;
return (
@@ -453,7 +333,7 @@ class TableList extends React.Component {
{
},
});
break;
-
- case 'Intermediate result':
+ case (cc):
+ // remove SEARCH_SPACE title
+ const realItem = item.replace(' (search space)', '');
showColumn.push({
- title: 'Intermediate result',
- dataIndex: 'intermediate',
- key: 'intermediate',
- width: '16%',
- render: (text: string, record: TableObj) => {
+ title: realItem,
+ dataIndex: item,
+ key: item,
+ width: '6%',
+ render: (text: string, record: TableRecord) => {
+ const eachTrial = TRIALS.getTrial(record.id);
return (
-
- Intermediate
-
+ {eachTrial.description.parameters[realItem]}
);
},
});
break;
default:
- showColumn.push({
- title: item,
- dataIndex: item,
- key: item,
- width: 150,
- render: (text: string, record: TableObj) => {
- const temp = record.acc;
- let decimals = 0;
- let other = '';
- if (temp !== undefined) {
- if (temp[item].toString().indexOf('.') !== -1) {
- decimals = temp[item].toString().length - temp[item].toString().indexOf('.') - 1;
- if (decimals > 6) {
- other = `${temp[item].toFixed(6)}`;
- } else {
- other = temp[item].toString();
- }
- }
- } else {
- other = '--';
- }
- return (
- {other}
- );
- }
- });
+ // FIXME
+ alert('Unexpected column type');
}
- });
+ }
return (
| null) => this.tables = table}
+ ref={(table: Table | null) => this.tables = table}
columns={showColumn}
rowSelection={rowSelection}
expandedRowRender={this.openRow}
dataSource={tableSource}
className="commonTableStyle"
- pagination={{ pageSize: entries }}
+ scroll={{x: 'max-content'}}
+ pagination={pageSize > 0 ? { pageSize } : false}
/>
{/* Intermediate Result Modal */}
{
}
}
+const SequenceIdColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Trial No.',
+ dataIndex: 'sequenceId',
+ width: 120,
+ className: 'tableHead',
+ sorter: (a, b) => a.sequenceId - b.sequenceId
+};
+
+const IdColumnConfig: ColumnProps<TableRecord> = {
+ title: 'ID',
+ dataIndex: 'id',
+ width: 60,
+ className: 'tableHead leftTitle',
+ sorter: (a, b) => a.id.localeCompare(b.id),
+ render: (text, record) => (
+ {record.id}
+ )
+};
+
+const StartTimeColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Start Time',
+ dataIndex: 'startTime',
+ width: 160,
+ render: (text, record) => (
+ {formatTimestamp(record.startTime)}
+ )
+};
+
+const EndTimeColumnConfig: ColumnProps<TableRecord> = {
+ title: 'End Time',
+ dataIndex: 'endTime',
+ width: 160,
+ render: (text, record) => (
+ {formatTimestamp(record.endTime, '--')}
+ )
+};
+
+const DurationColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Duration',
+ dataIndex: 'duration',
+ width: 100,
+ sorter: (a, b) => a.duration - b.duration,
+ render: (text, record) => (
+ {convertDuration(record.duration)}
+ )
+};
+
+const StatusColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Status',
+ dataIndex: 'status',
+ width: 150,
+ className: 'tableStatus',
+ render: (text, record) => (
+ {record.status}
+ ),
+ sorter: (a, b) => a.status.localeCompare(b.status),
+ filters: trialJobStatus.map(status => ({ text: status, value: status })),
+ onFilter: (value, record) => (record.status === value)
+};
+
+const IntermediateCountColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Intermediate result',
+ dataIndex: 'intermediateCount',
+ width: 86,
+ render: (text, record) => (
+ {`#${record.intermediateCount}`}
+ )
+};
+
+const AccuracyColumnConfig: ColumnProps<TableRecord> = {
+ title: 'Default metric',
+ className: 'leftTitle',
+ dataIndex: 'accuracy',
+ width: 120,
+ sorter: (a, b, sortOrder) => {
+ if (a.accuracy === undefined) {
+ return sortOrder === 'ascend' ? -1 : 1;
+ } else if (b.accuracy === undefined) {
+ return sortOrder === 'ascend' ? 1 : -1;
+ } else {
+ return a.accuracy - b.accuracy;
+ }
+ },
+ render: (text, record) => (
+ // TODO: is this needed?
+ {record.latestAccuracy}
+ )
+};
+
export default TableList;
diff --git a/src/webui/src/static/const.ts b/src/webui/src/static/const.ts
index f2708ebbee..368daa624c 100644
--- a/src/webui/src/static/const.ts
+++ b/src/webui/src/static/const.ts
@@ -1,3 +1,7 @@
+// When there are more trials than METRIC_GROUP_UPDATE_THRESHOLD, metrics are updated in groups of METRIC_GROUP_UPDATE_SIZE to avoid freezing
+const METRIC_GROUP_UPDATE_THRESHOLD = 100;
+const METRIC_GROUP_UPDATE_SIZE = 20;
+
const MANAGER_IP = `/api/v1/nni`;
const DOWNLOAD_IP = `/logs`;
const trialJobStatus = [
@@ -34,21 +38,29 @@ const COLUMN_INDEX = [
index: 2
},
{
- name: 'Duration',
+ name: 'Start Time',
index: 3
},
{
- name: 'Status',
+ name: 'End Time',
index: 4
},
{
- name: 'Intermediate count',
+ name: 'Duration',
index: 5
},
{
- name: 'Default',
+ name: 'Status',
index: 6
},
+ {
+ name: 'Intermediate result',
+ index: 7
+ },
+ {
+ name: 'Default',
+ index: 8
+ },
{
name: 'Operation',
index: 10000
@@ -57,8 +69,10 @@ const COLUMN_INDEX = [
// default selected columns
const COLUMN = ['Trial No.', 'ID', 'Duration', 'Status', 'Default', 'Operation'];
// all selectable columns !dictory final
-const COLUMNPro = ['Trial No.', 'ID', 'Duration', 'Status', 'Intermediate count', 'Default', 'Operation'];
+const COLUMNPro = ['Trial No.', 'ID', 'Start Time', 'End Time', 'Duration', 'Status',
+'Intermediate result', 'Default', 'Operation'];
export {
MANAGER_IP, DOWNLOAD_IP, trialJobStatus, COLUMNPro,
- CONTROLTYPE, MONACO, COLUMN, COLUMN_INDEX, DRAWEROPTION
+ CONTROLTYPE, MONACO, COLUMN, COLUMN_INDEX, DRAWEROPTION,
+ METRIC_GROUP_UPDATE_THRESHOLD, METRIC_GROUP_UPDATE_SIZE,
};
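The comment on the two new constants describes a batching policy: above a trial-count threshold, metrics are refreshed in fixed-size groups. A minimal sketch of what such a policy could look like follows. `planMetricGroups` and its parameters are hypothetical names used only for illustration; they do not appear in this diff, and the actual WebUI may schedule its updates differently.

```typescript
// Values copied from const.ts in this diff.
const METRIC_GROUP_UPDATE_THRESHOLD = 100;
const METRIC_GROUP_UPDATE_SIZE = 20;

// Hypothetical helper (not part of the PR): decide how trial IDs are grouped
// for metric updates. At or below the threshold everything goes in one group;
// above it, IDs are split into consecutive groups of `groupSize`.
function planMetricGroups(
    trialIds: string[],
    threshold: number = METRIC_GROUP_UPDATE_THRESHOLD,
    groupSize: number = METRIC_GROUP_UPDATE_SIZE
): string[][] {
    if (trialIds.length === 0) {
        return [];
    }
    if (trialIds.length <= threshold) {
        return [trialIds.slice()];
    }
    const groups: string[][] = [];
    for (let i = 0; i < trialIds.length; i += groupSize) {
        groups.push(trialIds.slice(i, i + groupSize));
    }
    return groups;
}
```

Under these assumptions, 100 trials are updated in a single pass, while 101 trials are split into six groups (five of 20 and one of 1).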
diff --git a/src/webui/src/static/datamodel.ts b/src/webui/src/static/datamodel.ts
new file mode 100644
index 0000000000..b47f7b2e84
--- /dev/null
+++ b/src/webui/src/static/datamodel.ts
@@ -0,0 +1,7 @@
+import { Experiment } from './model/experiment';
+import { TrialManager } from './model/trialmanager';
+
+const EXPERIMENT = new Experiment();
+const TRIALS = new TrialManager();
+
+export { EXPERIMENT, TRIALS };
diff --git a/src/webui/src/static/function.ts b/src/webui/src/static/function.ts
index 857c2fb0d4..be352ff35c 100644
--- a/src/webui/src/static/function.ts
+++ b/src/webui/src/static/function.ts
@@ -1,9 +1,12 @@
import axios from 'axios';
import { message } from 'antd';
import { MANAGER_IP } from './const';
-import { FinalResult, FinalType, TableObj } from './interface';
+import { MetricDataRecord, FinalType, TableObj } from './interface';
const convertTime = (num: number) => {
+ if (num <= 0) {
+ return '0';
+ }
if (num % 3600 === 0) {
return num / 3600 + 'h';
} else {
@@ -15,24 +18,28 @@ const convertTime = (num: number) => {
// trial's duration, accurate to seconds, for example 10min 30s
const convertDuration = (num: number) => {
+ if (num < 1) {
+ return '0s';
+ }
const hour = Math.floor(num / 3600);
- const min = Math.floor(num / 60 % 60);
+ const minute = Math.floor(num / 60 % 60);
const second = Math.floor(num % 60);
- const result = hour > 0 ? `${hour} h ${min} min ${second}s` : `${min} min ${second}s`;
- if (hour <= 0 && min === 0 && second !== 0) {
- return `${second}s`;
- } else if (hour === 0 && min !== 0 && second === 0) {
- return `${min}min`;
- } else if (hour === 0 && min !== 0 && second !== 0) {
- return `${min}min ${second}s`;
- } else {
- return result;
+ const result: string[] = [];
+ if (hour > 0) {
+ result.push(`${hour}h`);
+ }
+ if (minute > 0) {
+ result.push(`${minute}min`);
+ }
+ if (second > 0) {
+ result.push(`${second}s`);
}
+ return result.join(' ');
};
// get final result value
// draw Accuracy point graph
-const getFinalResult = (final: Array<FinalResult>) => {
+const getFinalResult = (final?: MetricDataRecord[]) => {
let acc;
let showDefault = 0;
if (final) {
@@ -51,7 +58,7 @@ const getFinalResult = (final: Array) => {
};
// get final result value // acc obj
-const getFinal = (final: Array<FinalResult>) => {
+const getFinal = (final?: MetricDataRecord[]) => {
let showDefault: FinalType;
if (final) {
showDefault = JSON.parse(final[final.length - 1].data);
@@ -101,7 +108,7 @@ const intermediateGraphOption = (intermediateArr: number[], id: string) => {
};
// kill job
-const killJob = (key: number, id: string, status: string, updateList: Function) => {
+const killJob = (key: number, id: string, status: string, updateList?: Function) => {
axios(`${MANAGER_IP}/trial-jobs/${id}`, {
method: 'DELETE',
headers: {
@@ -113,7 +120,9 @@ const killJob = (key: number, id: string, status: string, updateList: Function)
message.destroy();
message.success('Cancel the job successfully');
// render the table
- updateList();
+ if (updateList) {
+ updateList(); // FIXME
+ }
} else {
message.error('fail to cancel the job');
}
@@ -160,7 +169,22 @@ const downFile = (content: string, fileName: string) => {
}
};
+function formatTimestamp(timestamp?: number, placeholder: string = 'N/A'): string {
+ return timestamp ? new Date(timestamp).toLocaleString('en-US') : placeholder;
+}
+
+function metricAccuracy(metric: MetricDataRecord): number {
+ const data = JSON.parse(metric.data);
+ return typeof data === 'number' ? data : NaN;
+}
+
+function formatAccuracy(accuracy: number): string {
+ // TODO: how to format NaN?
+ return accuracy.toFixed(6).replace(/0+$/, '').replace(/\.$/, '');
+}
+
export {
convertTime, convertDuration, getFinalResult, getFinal, downFile,
- intermediateGraphOption, killJob, filterByStatus, filterDuration
+ intermediateGraphOption, killJob, filterByStatus, filterDuration,
+ formatAccuracy, formatTimestamp, metricAccuracy,
};
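The formatting helpers touched in function.ts can be exercised in isolation. The sketch below is adapted from `convertDuration`, `metricAccuracy`, and `formatAccuracy` as they appear in this diff. As an assumption made to keep the block self-contained, `MetricDataRecord` is reduced to the single `data` field these helpers read. It also shows the case the TODO asks about: `formatAccuracy(NaN)` currently yields the string `'NaN'`.

```typescript
// Reduced stand-in for MetricDataRecord (assumption: only `data` is needed here).
interface MetricData {
    data: string;
}

// Duration in seconds -> compact string; zero-valued units are omitted.
function convertDuration(num: number): string {
    if (num < 1) {
        return '0s';
    }
    const hour = Math.floor(num / 3600);
    const minute = Math.floor(num / 60 % 60);
    const second = Math.floor(num % 60);
    const result: string[] = [];
    if (hour > 0) {
        result.push(`${hour}h`);
    }
    if (minute > 0) {
        result.push(`${minute}min`);
    }
    if (second > 0) {
        result.push(`${second}s`);
    }
    return result.join(' ');
}

// A metric counts as an accuracy only when its JSON payload is a bare number.
function metricAccuracy(metric: MetricData): number {
    const data = JSON.parse(metric.data);
    return typeof data === 'number' ? data : NaN;
}

// Up to six decimals, with trailing zeros and a dangling '.' removed.
function formatAccuracy(accuracy: number): string {
    return accuracy.toFixed(6).replace(/0+$/, '').replace(/\.$/, '');
}

console.log(convertDuration(3661));                      // 1h 1min 1s
console.log(convertDuration(3600));                      // 1h
console.log(formatAccuracy(0.5));                        // 0.5
console.log(formatAccuracy(100));                        // 100
console.log(formatAccuracy(metricAccuracy({ data: '{"default": 0.9}' }))); // NaN
```

One consequence of dropping zero-valued units is that a duration of exactly one hour renders as `1h` rather than `1h 0min 0s`, which differs from the string the removed branch logic produced for that input.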
diff --git a/src/webui/src/static/interface.ts b/src/webui/src/static/interface.ts
index 15349e23ab..44789ab7f0 100644
--- a/src/webui/src/static/interface.ts
+++ b/src/webui/src/static/interface.ts
@@ -1,3 +1,5 @@
+// tslint:disable:no-any
+
// draw accuracy graph data interface
interface TableObj {
key: number;
@@ -8,6 +10,21 @@ interface TableObj {
acc?: FinalType; // draw accuracy graph
description: Parameters;
color?: string;
+ startTime?: number;
+ endTime?: number;
+}
+
+interface TableRecord {
+ key: string;
+ sequenceId: number;
+ startTime: number;
+ endTime?: number;
+ id: string;
+ duration: number;
+ status: string;
+ intermediateCount: number;
+ accuracy?: number;
+ latestAccuracy: string; // formatted string
}
interface SearchSpace {
@@ -30,26 +47,6 @@ interface Parameters {
multiProgress?: number;
}
-interface Experiment {
- id: string;
- author: string;
- revision?: number;
- experName: string;
- logDir?: string;
- runConcurren: number;
- maxDuration: number;
- execDuration: number;
- MaxTrialNum: number;
- startTime: number;
- endTime?: number;
- trainingServicePlatform: string;
- tuner: object;
- assessor?: object;
- advisor?: object;
- clusterMetaData?: object;
- logCollection?: string;
-}
-
// trial accuracy
interface AccurPoint {
acc: number;
@@ -72,21 +69,6 @@ interface TooltipForAccuracy {
data: Array;
}
-interface TrialNumber {
- succTrial: number;
- failTrial: number;
- stopTrial: number;
- waitTrial: number;
- runTrial: number;
- unknowTrial: number;
- totalCurrentTrial: number;
-}
-
-interface TrialJob {
- text: string;
- value: string;
-}
-
interface Dimobj {
dim: number;
name: string;
@@ -106,10 +88,6 @@ interface ParaObj {
parallelAxis: Array;
}
-interface FinalResult {
- data: string;
-}
-
interface Intermedia {
name: string; // id
type: string;
@@ -117,13 +95,93 @@ interface Intermedia {
hyperPara: object; // each trial hyperpara value
}
-interface ExperimentInfo {
- platform: string;
- optimizeMode: string;
+interface MetricDataRecord {
+ timestamp: number;
+ trialJobId: string;
+ parameterId: string;
+ type: string;
+ sequence: number;
+ data: string;
+}
+
+interface TrialJobInfo {
+ id: string;
+ sequenceId: number;
+ status: string;
+ startTime?: number;
+ endTime?: number;
+ hyperParameters?: string[];
+ logPath?: string;
+ finalMetricData?: MetricDataRecord[];
+ stderrPath?: string;
+}
+
+interface ExperimentParams {
+ authorName: string;
+ experimentName: string;
+ description?: string;
+ trialConcurrency: number;
+ maxExecDuration: number; // seconds
+ maxTrialNum: number;
+ searchSpace: string;
+ trainingServicePlatform: string;
+ multiPhase?: boolean;
+ multiThread?: boolean;
+ versionCheck?: boolean;
+ logCollection?: string;
+ tuner?: {
+ className: string;
+ builtinTunerName?: string;
+ codeDir?: string;
+ classArgs?: any;
+ classFileName?: string;
+ checkpointDir: string;
+ gpuNum?: number;
+ includeIntermediateResults?: boolean;
+ };
+ assessor?: {
+ className: string;
+ builtinAssessorName?: string;
+ codeDir?: string;
+ classArgs?: any;
+ classFileName?: string;
+ checkpointDir: string;
+ gpuNum?: number;
+ };
+ advisor?: {
+ className: string;
+ builtinAdvisorName?: string;
+ codeDir?: string;
+ classArgs?: any;
+ classFileName?: string;
+ checkpointDir: string;
+ gpuNum?: number;
+ };
+ clusterMetaData?: {
+ key: string;
+ value: string;
+ }[];
+}
+
+interface ExperimentProfile {
+ params: ExperimentParams;
+ id: string;
+ execDuration: number;
+ logDir?: string;
+ startTime?: number;
+ endTime?: number;
+ maxSequenceId: number;
+ revision: number;
+}
+
+interface NNIManagerStatus {
+ status: string;
+ errors: string[];
}
export {
- TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob,
- DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType,
- TooltipForIntermediate, SearchSpace, Intermedia, ExperimentInfo
+ TableObj, TableRecord, Parameters, ExperimentProfile, AccurPoint,
+ DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalType,
+ TooltipForIntermediate, SearchSpace, Intermedia, MetricDataRecord, TrialJobInfo,
+ NNIManagerStatus,
};
diff --git a/src/webui/src/static/model/experiment.ts b/src/webui/src/static/model/experiment.ts
new file mode 100644
index 0000000000..e5e751c5c6
--- /dev/null
+++ b/src/webui/src/static/model/experiment.ts
@@ -0,0 +1,87 @@
+import axios from 'axios';
+import { MANAGER_IP } from '../const';
+import { ExperimentProfile, NNIManagerStatus } from '../interface';
+
+function compareProfiles(profile1?: ExperimentProfile, profile2?: ExperimentProfile): boolean {
+ if (!profile1 || !profile2) {
+ return false;
+ }
+ const copy1 = Object.assign({}, profile1, { execDuration: undefined });
+ const copy2 = Object.assign({}, profile2, { execDuration: undefined });
+ return JSON.stringify(copy1) === JSON.stringify(copy2);
+}
+
+class Experiment {
+ private profileField?: ExperimentProfile = undefined;
+ private statusField?: NNIManagerStatus = undefined;
+
+ public async init(): Promise<void> {
+ while (!this.profileField || !this.statusField) {
+ await this.update();
+ }
+ }
+
+ public async update(): Promise