Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix nnictl bugs and add new feature (#75)
Browse files Browse the repository at this point in the history
* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic
  • Loading branch information
SparkSnail authored Sep 19, 2018
1 parent b58666a commit cdee9c3
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 32 deletions.
4 changes: 2 additions & 2 deletions tools/nnicmd/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('classArgs'): {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('speed'): int
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
Expand Down
34 changes: 27 additions & 7 deletions tools/nnicmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from nni_annotation import *
import random
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url
from .config_utils import Config
from .common_utils import get_yml_content, get_json_content, print_error, print_normal
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_process
from .constants import EXPERIMENT_SUCCESS_INFO, STDOUT_FULL_PATH, STDERR_FULL_PATH, LOG_DIR, REST_PORT, ERROR_INFO, NORMAL_INFO
from .webui_utils import start_web_ui, check_web_ui

Expand All @@ -40,7 +40,8 @@ def start_rest_server(port, platform, mode, experiment_id=None):
print_normal('Checking experiment...')
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if rest_port and check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if rest_port and running:
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
Expand All @@ -66,7 +67,12 @@ def set_trial_config(experiment_config, port):
value_dict['gpuNum'] = experiment_config['trial']['gpuNum']
request_data['trial_config'] = value_dict
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
return True if response.status_code == 200 else False
if check_response(response):
return True
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return False

def set_local_config(experiment_config, port):
'''set local configuration'''
Expand All @@ -79,9 +85,11 @@ def set_remote_config(experiment_config, port):
request_data['machine_list'] = experiment_config['machineList']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
err_message = ''
if not response or not response.status_code == 200:
if not response or not check_response(response):
if response is not None:
err_message = response.text
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message

#set trial_config
Expand Down Expand Up @@ -117,11 +125,22 @@ def set_experiment(experiment_config, mode, port):
{'key': 'trial_config', 'value': value_dict})

response = rest_post(experiment_url(port), json.dumps(request_data), 20)
return response if response.status_code == 200 else None
if check_response(response):
return response
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return None

def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config()
#Check if there is an experiment running
origin_rest_pid = nni_config.get_config('restServerPid')
if origin_rest_pid and detect_process(origin_rest_pid):
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
# start rest server
rest_process = start_rest_server(REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid)
Expand All @@ -144,7 +163,8 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No

# check rest server
print_normal('Checking restful server...')
if check_rest_server(REST_PORT):
running, _ = check_rest_server(REST_PORT)
if running:
print_normal('Restful server start success!')
else:
print_error('Restful server start failed!')
Expand Down
3 changes: 2 additions & 1 deletion tools/nnicmd/launcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def parse_tuner_content(experiment_config):

if experiment_config['tuner'].get('builtinTunerName') and experiment_config['tuner'].get('classArgs'):
experiment_config['tuner']['className'] = tuner_class_name_dict.get(experiment_config['tuner']['builtinTunerName'])
experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName'])
if tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName']):
experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName'])
elif experiment_config['tuner'].get('codeDir') and experiment_config['tuner'].get('classFileName') and experiment_config['tuner'].get('className'):
if not os.path.exists(os.path.join(experiment_config['tuner']['codeDir'], experiment_config['tuner']['classFileName'])):
raise ValueError('Tuner file directory is not valid!')
Expand Down
6 changes: 4 additions & 2 deletions tools/nnicmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from .nnictl_utils import *

def nni_help_info(*args):
print('please run "nnictl --help" to see nnictl guidance')
print('please run "nnictl {positional argument} --help" to see nnictl guidance')

def parse_args():
'''Definite the arguments users need to follow and input'''
parser = argparse.ArgumentParser(prog='nni ctl', description='use nni control')
parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
parser.set_defaults(func=nni_help_info)

# create subparsers for args with sub values
Expand Down Expand Up @@ -95,6 +95,8 @@ def parse_args():
parser_experiment_subparsers = parser_experiment.add_subparsers()
parser_experiment_show = parser_experiment_subparsers.add_parser('show', help='show the information of experiment')
parser_experiment_show.set_defaults(func=list_experiment)
parser_experiment_status = parser_experiment_subparsers.add_parser('status', help='show the status of experiment')
parser_experiment_status.set_defaults(func=experiment_status)

#parse config command
parser_config = subparsers.add_parser('config', help='get config information')
Expand Down
35 changes: 25 additions & 10 deletions tools/nnicmd/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import json
import datetime
from subprocess import call, check_output
from .rest_utils import rest_get, rest_delete, check_rest_server_quick
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .config_utils import Config
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url
from .constants import STDERR_FULL_PATH, STDOUT_FULL_PATH
Expand All @@ -47,7 +47,8 @@ def check_rest(args):
'''check if restful server is running'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
# check_rest_server_quick returns (True, response) only when the status
# endpoint answered with HTTP 200, so a truthy flag means the server is up.
# The previous condition was negated and reported the opposite state.
running, _ = check_rest_server_quick(rest_port)
if running:
    print_normal('Restful server is running...')
else:
    print_normal('Restful server is not running...')
Expand All @@ -62,9 +63,10 @@ def stop_experiment(args):
print_normal('Experiment is not running...')
stop_web_ui()
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(experiment_url(rest_port), 20)
if not response or response.status_code != 200:
if not response or not check_response(response):
print_error('Stop experiment failed!')
#sleep to wait rest handler done
time.sleep(3)
Expand All @@ -82,9 +84,10 @@ def trial_ls(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(trial_jobs_url(rest_port), 20)
if response and response.status_code == 200:
if response and check_response(response):
content = json.loads(response.text)
for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
Expand All @@ -102,9 +105,10 @@ def trial_kill(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.trialid), 20)
if response and response.status_code == 200:
if response and check_response(response):
print(response.text)
else:
print_error('Kill trial job failed...')
Expand All @@ -119,16 +123,27 @@ def list_experiment(args):
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text))
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
else:
print_error('List experiment failed...')
else:
print_error('Restful server is not running...')

def experiment_status(args):
    '''Pretty-print the JSON body returned by the quick rest server check.

    Reads the rest server port from the stored configuration, probes the
    server once, and prints either the formatted response or a notice
    that the server is not running. `args` is accepted but not used.
    '''
    port = Config().get_config('restServerPort')
    is_up, status_response = check_rest_server_quick(port)
    if is_up:
        payload = json.loads(status_response.text)
        print(json.dumps(payload, indent=4, sort_keys=True, separators=(',', ':')))
        return
    print_normal('Restful server is not running...')

def get_log_content(file_name, cmds):
'''use cmds to read config content'''
if os.path.exists(file_name):
Expand Down
12 changes: 9 additions & 3 deletions tools/nnicmd/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,22 @@ def check_rest_server(rest_port):
response = rest_get(check_status_url(rest_port), 20)
if response:
if response.status_code == 200:
return True
return True, response
else:
return False
return False, response
else:
time.sleep(3)
return False
return False, response

def check_rest_server_quick(rest_port):
    '''Probe the restful server's status URL a single time.

    Returns a `(ready, response)` pair: `(True, response)` when the
    response is truthy and carries status code 200, `(False, None)`
    in every other case.
    '''
    reply = rest_get(check_status_url(rest_port), 5)
    is_ready = bool(reply) and reply.status_code == 200
    if not is_ready:
        return False, None
    return True, reply

def check_response(response):
    '''Return True only for a truthy response whose status_code is 200.

    A missing (None) or otherwise falsy response yields False.
    '''
    return bool(response and response.status_code == 200)
9 changes: 5 additions & 4 deletions tools/nnicmd/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import json
import os
from .rest_utils import rest_put, rest_get, check_rest_server_quick
from .rest_utils import rest_put, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url
from .config_utils import Config
from .common_utils import get_json_content
Expand Down Expand Up @@ -56,13 +56,14 @@ def update_experiment_profile(key, value):
'''call restful server to update experiment profile'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
if response and check_response(response):
experiment_profile = json.loads(response.text)
experiment_profile['params'][key] = value
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), 20)
if response and response.status_code == 200:
if response and check_response(response):
return response
else:
print('ERROR: restful server is not running...')
Expand Down
8 changes: 5 additions & 3 deletions tools/nnicmd/webui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import os
import psutil
from socket import AddressFamily
from subprocess import Popen, PIPE
from .rest_utils import rest_get
from subprocess import Popen, PIPE, call
from .rest_utils import rest_get, check_response
from .config_utils import Config
from .common_utils import print_error, print_normal
from .constants import STDOUT_FULL_PATH, STDERR_FULL_PATH
Expand Down Expand Up @@ -71,6 +71,8 @@ def stop_web_ui():
child_process.kill()
if parent_process.is_running():
parent_process.kill()
cmds = ['pkill', '-P', str(webuiPid)]
call(cmds)
return True
except Exception as e:
print_error(e)
Expand All @@ -84,6 +86,6 @@ def check_web_ui():
return False
for url in url_list:
response = rest_get(url, 3)
if response and response.status_code == 200:
if response and check_response(response):
return True
return False

0 comments on commit cdee9c3

Please sign in to comment.