Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add config v2 example and refine validation error message (#3248)
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster authored Jan 4, 2021
1 parent bdb2826 commit 3423117
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 9 deletions.
23 changes: 23 additions & 0 deletions examples/trials/mnist-pytorch/config_v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
searchSpace:
momentum:
_type: uniform
_value: [0, 1]
hidden_size:
_type: choice
_value: [128, 256, 512, 1024]
batch_size:
_type: choice
_value: [16, 32, 64, 128]
lr:
_type: choice
_value: [0.0001, 0.001, 0.01, 0.1]
trainingService:
platform: local
trialCodeDirectory: .
trialCommand: python3 mnist.py
trialConcurrency: 1
trialGpuNumber: 0
tuner:
name: TPE
classArgs:
optimize_mode: maximize
26 changes: 26 additions & 0 deletions examples/trials/mnist-tfv2/config_v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
searchSpace:
dropout_rate:
_type: uniform
_value: [0.5, 0.9]
conv_size:
_type: choice
_value: [2, 3, 5, 7]
hidden_size:
_type: choice
_value: [128, 512, 1024]
batch_size:
_type: choice
_value: [16, 32]
learning_rate:
_type: choice
_value: [0.0001, 0.001, 0.01, 0.1]
trainingService:
platform: local
trialCodeDirectory: .
trialCommand: python3 mnist.py
trialConcurrency: 1
trialGpuNumber: 0
tuner:
name: TPE
classArgs:
optimize_mode: maximize
23 changes: 14 additions & 9 deletions nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user

from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
Expand Down Expand Up @@ -592,16 +592,21 @@ def create_experiment(args):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception:
pass

try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(e)
exit(1)
except Exception:
print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception as e:
print_error(f'Conversion from v2 format failed: {repr(e)}')
try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(f'Config validation failed. {repr(e)}')
exit(1)

nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
Expand Down

0 comments on commit 3423117

Please sign in to comment.