From d7a62f66b7fa13898ad5d7e48f42b9dd8602bb14 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 31 Oct 2019 17:03:03 +0800 Subject: [PATCH 1/4] check pylint for nni_cmd --- tools/nni_cmd/command_utils.py | 2 +- tools/nni_cmd/common_utils.py | 10 +-- tools/nni_cmd/config_schema.py | 104 ++++++++++++++-------------- tools/nni_cmd/config_utils.py | 29 ++++---- tools/nni_cmd/constants.py | 5 +- tools/nni_cmd/launcher.py | 46 ++++++------ tools/nni_cmd/launcher_utils.py | 20 +++--- tools/nni_cmd/nnictl.py | 16 +++-- tools/nni_cmd/nnictl_utils.py | 66 +++++++++--------- tools/nni_cmd/package_management.py | 6 +- tools/nni_cmd/ssh_utils.py | 6 +- tools/nni_cmd/tensorboard_utils.py | 25 +++---- tools/nni_cmd/updater.py | 2 +- tools/nni_cmd/url_utils.py | 6 +- 14 files changed, 171 insertions(+), 172 deletions(-) diff --git a/tools/nni_cmd/command_utils.py b/tools/nni_cmd/command_utils.py index a3bcb81965..cf13f63eae 100644 --- a/tools/nni_cmd/command_utils.py +++ b/tools/nni_cmd/command_utils.py @@ -3,7 +3,7 @@ import os import signal import psutil -from .common_utils import print_error, print_normal, print_warning +from .common_utils import print_error def check_output_command(file_path, head=None, tail=None): diff --git a/tools/nni_cmd/common_utils.py b/tools/nni_cmd/common_utils.py index 3a5e909ca2..af0fe3efa6 100644 --- a/tools/nni_cmd/common_utils.py +++ b/tools/nni_cmd/common_utils.py @@ -21,10 +21,10 @@ import os import sys import json -import ruamel.yaml as yaml -import psutil import socket from pathlib import Path +import ruamel.yaml as yaml +import psutil from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO, COLOR_RED_FORMAT, COLOR_YELLOW_FORMAT def get_yml_content(file_path): @@ -34,6 +34,7 @@ def get_yml_content(file_path): return yaml.load(file, Loader=yaml.Loader) except yaml.scanner.ScannerError as err: print_error('yaml file format error!') + print_error(err) exit(1) except Exception as exception: print_error(exception) @@ -46,6 +47,7 @@ def get_json_content(file_path): return json.load(file) except TypeError as err: print_error('json file format error!') + print_error(err) return None def print_error(content): @@ -70,7 +72,7 @@ def detect_process(pid): def detect_port(port): '''Detect if the port is used''' - socket_test = socket.socket(socket.AF_INET,socket.SOCK_STREAM) + socket_test = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: socket_test.connect(('127.0.0.1', int(port))) socket_test.close() @@ -79,7 +81,7 @@ def detect_port(port): return False def get_user(): - if sys.platform =='win32': + if sys.platform == 'win32': return os.environ['USERNAME'] else: return os.environ['USER'] diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index da943564fb..dded8d1e95 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -19,13 +19,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import os -from schema import Schema, And, Use, Optional, Regex, Or +from schema import Schema, And, Optional, Regex, Or from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR -def setType(key, type): +def setType(key, valueType): '''check key type''' - return And(type, error=SCHEMA_TYPE_ERROR % (key, type.__name__)) + return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__)) def setChoice(key, *args): '''check choice''' @@ -47,7 +47,7 @@ def setPathCheck(key): 'experimentName': setType('experimentName', str), Optional('description'): setType('description', str), 'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999), - Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$',error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), + Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), 'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), @@ -106,7 +106,7 @@ def setPathCheck(key): 'builtinTunerName': 'NetworkMorphism', Optional('classArgs'): { Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'), - Optional('task'): setChoice('task', 'cv','nlp','common'), + Optional('task'): setChoice('task', 'cv', 'nlp', 'common'), Optional('input_width'): setType('input_width', int), Optional('input_channel'): setType('input_channel', int), Optional('n_output_node'): setType('n_output_node', int), @@ -139,7 +139,7 @@ def setPathCheck(key): Optional('selection_num_warm_up'): setType('selection_num_warm_up', int), Optional('selection_num_starting_points'): setType('selection_num_starting_points', int), }, - Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool), + Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool), Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), }, 'PPOTuner': { @@ -232,35 +232,35 @@ def setPathCheck(key): } common_trial_schema = { -'trial':{ - 'command': setType('command', str), - 'codeDir': setPathCheck('codeDir'), - Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), - Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode') + 'trial':{ + 'command': setType('command', str), + 'codeDir': setPathCheck('codeDir'), + Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), + Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode') } } pai_trial_schema = { -'trial':{ - 'command': setType('command', str), - 'codeDir': setPathCheck('codeDir'), - 'gpuNum': setNumberRange('gpuNum', int, 0, 99999), - 'cpuNum': setNumberRange('cpuNum', int, 0, 99999), - 'memoryMB': setType('memoryMB', int), - 'image': setType('image', str), - Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'), - Optional('shmMB'): setType('shmMB', int), - Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ - error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), - Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ - error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), - Optional('virtualCluster'): setType('virtualCluster', str), - Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), - Optional('portList'): [{ - "label": setType('label', str), - "beginAt": setType('beginAt', int), - "portNumber": setType('portNumber', int) - }] + 'trial':{ + 'command': setType('command', str), + 'codeDir': setPathCheck('codeDir'), + 'gpuNum': setNumberRange('gpuNum', int, 0, 99999), + 'cpuNum': setNumberRange('cpuNum', int, 0, 99999), + 'memoryMB': setType('memoryMB', int), + 'image': setType('image', str), + Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'), + Optional('shmMB'): setType('shmMB', int), + Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ + error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), + Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\ + error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'), + Optional('virtualCluster'): setType('virtualCluster', str), + Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), + Optional('portList'): [{ + "label": setType('label', str), + "beginAt": setType('beginAt', int), + "portNumber": setType('portNumber', int) + }] } } @@ -273,7 +273,7 @@ def setPathCheck(key): } kubeflow_trial_schema = { -'trial':{ + 'trial':{ 'codeDir': setPathCheck('codeDir'), Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), Optional('ps'): { @@ -315,7 +315,7 @@ def setPathCheck(key): 'server': setType('server', str), 'path': setType('path', str) } - },{ + }, { 'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'), 'apiVersion': setType('apiVersion', str), Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), @@ -363,7 +363,7 @@ def setPathCheck(key): 'server': setType('server', str), 'path': setType('path', str) } - },{ + }, { Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'), Optional('serviceAccountName'): setType('serviceAccountName', str), 'keyVault': { @@ -383,24 +383,24 @@ def setPathCheck(key): } machine_list_schema = { -Optional('machineList'):[Or({ - 'ip': setType('ip', str), - Optional('port'): setNumberRange('port', int, 1, 65535), - 'username': setType('username', str), - 'passwd': setType('passwd', str), - Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), - Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool) - },{ - 'ip': setType('ip', str), - Optional('port'): setNumberRange('port', int, 1, 65535), - 'username': setType('username', str), - 'sshKeyPath': setPathCheck('sshKeyPath'), - Optional('passphrase'): setType('passphrase', str), - Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), - Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool) -})] + Optional('machineList'):[Or({ + 'ip': setType('ip', str), + Optional('port'): setNumberRange('port', int, 1, 65535), + 'username': setType('username', str), + 'passwd': setType('passwd', str), + Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), + Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), + Optional('useActiveGpu'): setType('useActiveGpu', bool) + }, { + 'ip': setType('ip', str), + Optional('port'): setNumberRange('port', int, 1, 65535), + 'username': setType('username', str), + 'sshKeyPath': setPathCheck('sshKeyPath'), + Optional('passphrase'): setType('passphrase', str), + Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), + Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), + Optional('useActiveGpu'): setType('useActiveGpu', bool) + })] } LOCAL_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema}) diff --git a/tools/nni_cmd/config_utils.py b/tools/nni_cmd/config_utils.py index 6b2b8a0cc0..c5a36b374d 100644 --- a/tools/nni_cmd/config_utils.py +++ b/tools/nni_cmd/config_utils.py @@ -21,7 +21,6 @@ import os import json -import shutil from .constants import NNICTL_HOME_DIR class Config: @@ -73,29 +72,29 @@ def __init__(self): self.experiment_file = os.path.join(NNICTL_HOME_DIR, '.experiment') self.experiments = self.read_file() - def add_experiment(self, id, port, time, file_name, platform): + def add_experiment(self, expId, port, time, file_name, platform): '''set {key:value} paris to self.experiment''' - self.experiments[id] = {} - self.experiments[id]['port'] = port - self.experiments[id]['startTime'] = time - self.experiments[id]['endTime'] = 'N/A' - self.experiments[id]['status'] = 'INITIALIZED' - self.experiments[id]['fileName'] = file_name - self.experiments[id]['platform'] = platform + self.experiments[expId] = {} + self.experiments[expId]['port'] = port + self.experiments[expId]['startTime'] = time + self.experiments[expId]['endTime'] = 'N/A' + self.experiments[expId]['status'] = 'INITIALIZED' + self.experiments[expId]['fileName'] = file_name + self.experiments[expId]['platform'] = platform self.write_file() - def update_experiment(self, id, key, value): + def update_experiment(self, expId, key, value): '''Update experiment''' if id not in self.experiments: return False - self.experiments[id][key] = value + self.experiments[expId][key] = value self.write_file() return True - def remove_experiment(self, id): + def remove_experiment(self, expId): '''remove an experiment by id''' if id in self.experiments: - self.experiments.pop(id) + self.experiments.pop(expId) self.write_file() def get_all_experiments(self): @@ -109,7 +108,7 @@ def write_file(self): json.dump(self.experiments, file) except IOError as error: print('Error:', error) - return + return '' def read_file(self): '''load config from local file''' @@ -119,4 +118,4 @@ def read_file(self): return json.load(file) except ValueError: return {} - return {} + return {} diff --git a/tools/nni_cmd/constants.py b/tools/nni_cmd/constants.py index d22a509c46..0777d2db98 100644 --- a/tools/nni_cmd/constants.py +++ b/tools/nni_cmd/constants.py @@ -21,7 +21,7 @@ import os from colorama import Fore -NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl') +NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl') ERROR_INFO = 'ERROR: %s' @@ -58,7 +58,8 @@ '-----------------------------------------------------------------------\n' EXPERIMENT_START_FAILED_INFO = 'There is an experiment running in the port %d, please stop it first or set another port!\n' \ - 'You could use \'nnictl stop --port [PORT]\' command to stop an experiment!\nOr you could use \'nnictl create --config [CONFIG_PATH] --port [PORT]\' to set port!\n' + 'You could use \'nnictl stop --port [PORT]\' command to stop an experiment!\nOr you could ' \ + 'use \'nnictl create --config [CONFIG_PATH] --port [PORT]\' to set port!\n' EXPERIMENT_INFORMATION_FORMAT = '----------------------------------------------------------------------------------------\n' \ ' Experiment information\n' \ diff --git a/tools/nni_cmd/launcher.py b/tools/nni_cmd/launcher.py index e2fac2cb42..f99f8dfe43 100644 --- a/tools/nni_cmd/launcher.py +++ b/tools/nni_cmd/launcher.py @@ -22,22 +22,21 @@ import json import os import sys -import shutil import string -from subprocess import Popen, PIPE, call, check_output, check_call, CalledProcessError +import random +import site +import time import tempfile +from subprocess import Popen, check_call, CalledProcessError +from nni_annotation import expand_annotations, generate_search_space from nni.constants import ModuleName, AdvisorModuleName -from nni_annotation import * from .launcher_utils import validate_all_content -from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response +from .rest_utils import rest_put, rest_post, check_rest_server, check_response from .url_utils import cluster_metadata_url, experiment_url, get_local_urls from .config_utils import Config, Experiments -from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process, detect_port, get_user, get_python_dir -from .constants import * -import random -import site -import time -from pathlib import Path +from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \ + detect_port, get_user, get_python_dir +from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS from .command_utils import check_output_command, kill_command from .nnictl_utils import update_experiment @@ -83,7 +82,8 @@ def _generate_installation_path(sitepackages_path): python_dir = os.getenv('VIRTUAL_ENV') else: python_sitepackage = site.getsitepackages()[0] - # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given that nni exists there + # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given + # that nni exists there if python_sitepackage.startswith('/usr') or python_sitepackage.startswith('/Library'): python_dir = try_installation_path_sequentially(site.getusersitepackages(), site.getsitepackages()[0]) else: @@ -98,7 +98,6 @@ def _generate_installation_path(sitepackages_path): def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None): '''Run nni manager process''' - nni_config = Config(config_file_name) if detect_port(port): print_error('Port %s is used by another process, please reset the port!\n' \ 'You could use \'nnictl create --help\' to get help information' % port) @@ -114,7 +113,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None entry_dir = get_nni_installation_path() entry_file = os.path.join(entry_dir, 'main.js') - + node_command = 'node' if sys.platform == 'win32': node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe') @@ -132,7 +131,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None cmds += ['--experiment_id', experiment_id] stdout_full_path, stderr_full_path = get_log_path(config_file_name) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: - time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) + time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) #add time information in the header of log files log_header = LOG_HEADER % str(time_now) stdout_file.write(log_header) @@ -212,7 +211,7 @@ def setNNIManagerIp(experiment_config, port, config_file_name): if experiment_config.get('nniManagerIp') is None: return True, None ip_config_dict = dict() - ip_config_dict['nni_manager_ip'] = { 'nniManagerIp' : experiment_config['nniManagerIp'] } + ip_config_dict['nni_manager_ip'] = {'nniManagerIp': experiment_config['nniManagerIp']} response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: @@ -403,11 +402,12 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen stdout_full_path, stderr_full_path = get_log_path(config_file_name) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file) - except CalledProcessError as e: + except CalledProcessError: print_error('some errors happen when import package %s.' %(package_name)) print_log_content(config_file_name) if package_name in PACKAGE_REQUIREMENTS: - print_error('If %s is not installed, it should be installed through \'nnictl package install --name %s\''%(package_name, package_name)) + print_error('If %s is not installed, it should be installed through '\ + '\'nnictl package install --name %s\''%(package_name, package_name)) exit(1) log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None @@ -416,7 +416,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True): log_level = 'debug' # start rest server - rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level) + rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ + mode, config_file_name, experiment_id, log_dir, log_level) nni_config.set_config('restServerPid', rest_process.pid) # Deal with annotation if experiment_config.get('useAnnotation'): @@ -450,8 +451,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen exit(1) if mode != 'view': # set platform configuration - set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port, config_file_name, rest_process) - + set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ + config_file_name, rest_process) + # start a new experiment print_normal('Starting experiment...') # set debug configuration @@ -478,7 +480,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen #save experiment information nnictl_experiment_config = Experiments() - nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name, experiment_config['trainingServicePlatform']) + nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,\ + experiment_config['trainingServicePlatform']) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) @@ -503,7 +506,6 @@ def manage_stopped_experiment(args, mode): experiment_config = Experiments() experiment_dict = experiment_config.get_all_experiments() experiment_id = None - experiment_endTime = None #find the latest stopped experiment if not args.id: print_error('Please set experiment id! \nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \ diff --git a/tools/nni_cmd/launcher_utils.py b/tools/nni_cmd/launcher_utils.py index da6a668064..f6c849abab 100644 --- a/tools/nni_cmd/launcher_utils.py +++ b/tools/nni_cmd/launcher_utils.py @@ -20,11 +20,11 @@ import os import json -from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \ -tuner_schema_dict, advisor_schema_dict, assessor_schema_dict -from schema import SchemaMissingKeyError, SchemaForbiddenKeyError, SchemaUnexpectedTypeError, SchemaWrongKeyError, SchemaError -from .common_utils import get_json_content, print_error, print_warning, print_normal -from schema import Schema, And, Use, Optional, Regex, Or +from schema import SchemaError +from schema import Schema +from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\ + FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict +from .common_utils import print_error, print_warning, print_normal def expand_path(experiment_config, key): '''Change '~' to user home directory''' @@ -164,11 +164,11 @@ def validate_common_content(experiment_config): print_error('Please set correct trainingServicePlatform!') exit(1) schema_dict = { - 'local': LOCAL_CONFIG_SCHEMA, - 'remote': REMOTE_CONFIG_SCHEMA, - 'pai': PAI_CONFIG_SCHEMA, - 'kubeflow': KUBEFLOW_CONFIG_SCHEMA, - 'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA + 'local': LOCAL_CONFIG_SCHEMA, + 'remote': REMOTE_CONFIG_SCHEMA, + 'pai': PAI_CONFIG_SCHEMA, + 'kubeflow': KUBEFLOW_CONFIG_SCHEMA, + 'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA } separate_schema_dict = { 'tuner': tuner_schema_dict, diff --git a/tools/nni_cmd/nnictl.py b/tools/nni_cmd/nnictl.py index 8da30fdfb7..88ee311423 100644 --- a/tools/nni_cmd/nnictl.py +++ b/tools/nni_cmd/nnictl.py @@ -20,14 +20,18 @@ import argparse +import os import pkg_resources +from colorama import init +from .common_utils import print_error from .launcher import create_experiment, resume_experiment, view_experiment from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data -from .nnictl_utils import * -from .package_management import * -from .constants import * -from .tensorboard_utils import * -from colorama import init +from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment, experiment_status,\ + log_trial, experiment_clean, platform_clean, experiment_list, \ + monitor_experiment, export_trials_data, trial_codegen, webui_url, get_config, log_stdout, log_stderr +from .package_management import package_install, package_show +from .constants import DEFAULT_REST_PORT +from .tensorboard_utils import start_tensorboard, stop_tensorboard init(autoreset=True) if os.environ.get('COVERAGE_PROCESS_START'): @@ -38,7 +42,7 @@ def nni_info(*args): if args[0].version: try: print(pkg_resources.get_distribution('nni').version) - except pkg_resources.ResolutionError as err: + except pkg_resources.ResolutionError: print_error('Get version failed, please use `pip3 list | grep nni` to check nni version!') else: print('please run "nnictl {positional argument} --help" to see nnictl guidance') diff --git a/tools/nni_cmd/nnictl_utils.py b/tools/nni_cmd/nnictl_utils.py index b6fada56e8..4cadce182d 100644 --- a/tools/nni_cmd/nnictl_utils.py +++ b/tools/nni_cmd/nnictl_utils.py @@ -20,15 +20,13 @@ import csv import os -import psutil import json -from datetime import datetime, timezone import time import re -from pathlib import Path -from pyhdfs import HdfsClient, HdfsFileNotFoundException import shutil -from subprocess import call, check_output +from datetime import datetime, timezone +from pathlib import Path +from pyhdfs import HdfsClient from nni_annotation import expand_annotations from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url @@ -102,7 +100,8 @@ def check_experiment_id(args, update=True): experiment_information = "" for key in running_experiment_list: experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \ - experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])) + experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'],\ + experiment_dict[key]['endTime'])) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) exit(1) elif not running_experiment_list: @@ -157,23 +156,24 @@ def parse_ids(args): experiment_information = "" for key in running_experiment_list: experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \ - experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])) + experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \ + experiment_dict[key]['endTime'])) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) exit(1) else: result_list = running_experiment_list elif args.id.endswith('*'): - for id in running_experiment_list: - if id.startswith(args.id[:-1]): - result_list.append(id) + for expId in running_experiment_list: + if expId.startswith(args.id[:-1]): + result_list.append(expId) elif args.id in running_experiment_list: result_list.append(args.id) else: - for id in running_experiment_list: - if id.startswith(args.id): - result_list.append(id) + for expId in running_experiment_list: + if expId.startswith(args.id): + result_list.append(expId) if len(result_list) > 1: - print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list) ) + print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list)) return None if not result_list and (args.id or args.port): print_error('There are no experiments matched, please set correct experiment id or restful server port') @@ -235,7 +235,6 @@ def stop_experiment(args): for experiment_id in experiment_id_list: print_normal('Stoping experiment %s' % experiment_id) nni_config = Config(experiment_dict[experiment_id]['fileName']) - rest_port = nni_config.get_config('restServerPort') rest_pid = nni_config.get_config('restServerPid') if rest_pid: kill_command(rest_pid) @@ -249,7 +248,7 @@ def stop_experiment(args): nni_config.set_config('tensorboardPidList', []) print_normal('Stop experiment success.') experiment_config.update_experiment(experiment_id, 'status', 'STOPPED') - time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) + time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) experiment_config.update_experiment(experiment_id, 'endTime', str(time_now)) def trial_ls(args): @@ -401,9 +400,9 @@ def local_clean(directory): print_normal('removing folder {0}'.format(directory)) try: shutil.rmtree(directory) - except FileNotFoundError as err: + except FileNotFoundError: print_error('{0} does not exist.'.format(directory)) - + def remote_clean(machine_list, experiment_id=None): '''clean up remote data''' for machine in machine_list: @@ -418,7 +417,7 @@ def remote_clean(machine_list, experiment_id=None): sftp = create_ssh_sftp_client(host, port, userName, passwd) print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir)) remove_remote_directory(sftp, remote_dir) - + def hdfs_clean(host, user_name, output_dir, experiment_id=None): '''clean up hdfs data''' hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5) @@ -475,7 +474,7 @@ def experiment_clean(args): machine_list = nni_config.get_config('experimentConfig').get('machineList') remote_clean(machine_list, experiment_id) elif platform == 'pai': - host = nni_config.get_config('experimentConfig').get('paiConfig').get('host') + host = nni_config.get_config('experimentConfig').get('paiConfig').get('host') user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName') output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir') hdfs_clean(host, user_name, output_dir, experiment_id) @@ -492,7 +491,7 @@ def experiment_clean(args): experiment_config = Experiments() print_normal('removing metadata of experiment {0}'.format(experiment_id)) experiment_config.remove_experiment(experiment_id) - print_normal('Done.') + print_normal('Done.') def get_platform_dir(config_content): '''get the dir list to be deleted''' @@ -505,8 +504,7 @@ def get_platform_dir(config_content): port = machine.get('port') dir_list.append(host + ':' + str(port) + '/tmp/nni') elif platform == 'pai': - pai_config = config_content.get('paiConfig') - host = config_content.get('paiConfig').get('host') + host = config_content.get('paiConfig').get('host') user_name = config_content.get('paiConfig').get('userName') output_dir = config_content.get('trial').get('outputDir') dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name)) @@ -529,17 +527,15 @@ def platform_clean(args): print_normal('platform {0} not supported.'.format(platform)) exit(0) update_experiment() - experiment_config = Experiments() - experiment_dict = experiment_config.get_all_experiments() - id_list = list(experiment_dict.keys()) dir_list = get_platform_dir(config_content) if not dir_list: print_normal('No folder of NNI caches is found.') exit(1) while True: - print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.') - for dir in dir_list: - print(' ' + dir) + print_normal('This command will remove below folders of NNI caches. If other users are using experiments' \ + ' on below hosts, it will be broken.') + for value in dir_list: + print(' ' + value) inputs = input('INFO: do you want to continue?[y/N]:') if not inputs.lower() or inputs.lower() in ['n', 'no']: exit(0) @@ -549,11 +545,9 @@ def platform_clean(args): break if platform == 'remote': machine_list = config_content.get('machineList') - for machine in machine_list: - remote_clean(machine_list, None) + remote_clean(machine_list, None) elif platform == 'pai': - pai_config = config_content.get('paiConfig') - host = config_content.get('paiConfig').get('host') + host = config_content.get('paiConfig').get('host') user_name = config_content.get('paiConfig').get('userName') output_dir = config_content.get('trial').get('outputDir') hdfs_clean(host, user_name, output_dir, None) @@ -618,7 +612,8 @@ def show_experiment_info(): return for key in experiment_id_list: print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \ - experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) + experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \ + get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) print(TRIAL_MONITOR_HEAD) running, response = check_rest_server_quick(experiment_dict[key]['port']) if running: @@ -627,7 +622,8 @@ def show_experiment_info(): content = json.loads(response.text) for index, value in enumerate(content): content[index] = convert_time_stamp_to_date(value) - print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), content[index].get('endTime'), content[index].get('status'))) + print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), \ + content[index].get('endTime'), content[index].get('status'))) print(TRIAL_MONITOR_TAIL) def monitor_experiment(args): diff --git a/tools/nni_cmd/package_management.py b/tools/nni_cmd/package_management.py index de8dbe62ec..32ed79496d 100644 --- a/tools/nni_cmd/package_management.py +++ b/tools/nni_cmd/package_management.py @@ -18,12 +18,10 @@ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -import nni import os -import sys -from subprocess import call +import nni from .constants import PACKAGE_REQUIREMENTS -from .common_utils import print_normal, print_error +from .common_utils import print_error from .command_utils import install_requirements_command def process_install(package_name): diff --git a/tools/nni_cmd/ssh_utils.py b/tools/nni_cmd/ssh_utils.py index da707dac48..7453830323 100644 --- a/tools/nni_cmd/ssh_utils.py +++ b/tools/nni_cmd/ssh_utils.py @@ -20,7 +20,6 @@ import os from .common_utils import print_error -from subprocess import call from .command_utils import install_package_command def check_environment(): @@ -29,6 +28,8 @@ def check_environment(): import paramiko except: install_package_command('paramiko') + import paramiko + return paramiko def copy_remote_directory_to_local(sftp, remote_path, local_path): '''copy remote directory to local machine''' @@ -49,8 +50,7 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path): def create_ssh_sftp_client(host_ip, port, username, password): '''create ssh client''' try: - check_environment() - import paramiko + paramiko = check_environment() conn = paramiko.Transport(host_ip, port) conn.connect(username=username, password=password) sftp = paramiko.SFTPClient.from_transport(conn) diff --git a/tools/nni_cmd/tensorboard_utils.py b/tools/nni_cmd/tensorboard_utils.py index b4578c34b0..9646b4de0e 100644 --- a/tools/nni_cmd/tensorboard_utils.py +++ b/tools/nni_cmd/tensorboard_utils.py @@ -19,21 +19,17 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import os -import psutil import json -import datetime -import time -from subprocess import call, check_output, Popen, PIPE -from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response -from .config_utils import Config, Experiments -from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, get_local_urls -from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, COLOR_GREEN_FORMAT -import time -from .common_utils import print_normal, print_error, print_warning, detect_process, detect_port -from .nnictl_utils import * import re -from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local import tempfile +from subprocess import call, Popen +from .rest_utils import rest_get, check_rest_server_quick, check_response +from .config_utils import Config, Experiments +from .url_utils import trial_jobs_url, get_local_urls +from .constants import COLOR_GREEN_FORMAT, REST_TIME_OUT +from .common_utils import print_normal, print_error, detect_process, detect_port +from .nnictl_utils import check_experiment_id, check_experiment_id +from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local def parse_log_path(args, trial_content): '''parse log path''' @@ -43,7 +39,7 @@ def parse_log_path(args, trial_content): if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id: continue pattern = r'(?P.+)://(?P.+):(?P.*)' - match = re.search(pattern,trial['logPath']) + match = re.search(pattern, trial['logPath']) if match: path_list.append(match.group('path')) host_list.append(match.group('host')) @@ -94,7 +90,8 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path): if detect_port(args.port): print_error('Port %s is used by another process, please reset port!' % str(args.port)) exit(1) - with open(os.path.join(temp_nni_path, 'tensorboard_stdout'), 'a+') as stdout_file, open(os.path.join(temp_nni_path, 'tensorboard_stderr'), 'a+') as stderr_file: + with open(os.path.join(temp_nni_path, 'tensorboard_stdout'), 'a+') as stdout_file, \ + open(os.path.join(temp_nni_path, 'tensorboard_stderr'), 'a+') as stderr_file: cmds = ['tensorboard', '--logdir', format_tensorboard_log_path(path_list), '--port', str(args.port)] tensorboard_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file) url_list = get_local_urls(args.port) diff --git a/tools/nni_cmd/updater.py b/tools/nni_cmd/updater.py index 9258d73f0a..07ae6123cb 100644 --- a/tools/nni_cmd/updater.py +++ b/tools/nni_cmd/updater.py @@ -25,7 +25,7 @@ from .url_utils import experiment_url, import_data_url from .config_utils import Config from .common_utils import get_json_content, print_normal, print_error, print_warning -from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename +from .nnictl_utils import get_experiment_port, get_config_filename from .launcher_utils import parse_time from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA diff --git a/tools/nni_cmd/url_utils.py b/tools/nni_cmd/url_utils.py index c50b2551d2..05cfa8e66f 100644 --- a/tools/nni_cmd/url_utils.py +++ b/tools/nni_cmd/url_utils.py @@ -18,8 +18,8 @@ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +import socket import psutil -from socket import AddressFamily BASE_URL = 'http://localhost' @@ -83,8 +83,8 @@ def tensorboard_url(port): def get_local_urls(port): '''get urls of local machine''' url_list = [] - for name, info in psutil.net_if_addrs().items(): + for _, info in psutil.net_if_addrs().items(): for addr in info: - if AddressFamily.AF_INET == addr.family: + if socket.AddressFamily.AF_INET == addr.family: url_list.append('http://{}:{}'.format(addr.address, port)) return url_list From e259d109fea97fbce6f81b3081390fcb99d594fa Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 31 Oct 2019 17:55:41 +0800 Subject: [PATCH 2/4] fix id error --- tools/nni_cmd/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/nni_cmd/config_utils.py b/tools/nni_cmd/config_utils.py index c5a36b374d..c7c88bcf3e 100644 --- a/tools/nni_cmd/config_utils.py +++ b/tools/nni_cmd/config_utils.py @@ -85,7 +85,7 @@ def add_experiment(self, expId, port, time, file_name, platform): def update_experiment(self, expId, key, value): '''Update experiment''' - if id not in self.experiments: + if expId not in self.experiments: return False self.experiments[expId][key] = value self.write_file() From 3a81e0da760fd178cb78d1d5607ec6e0081f360b Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Wed, 20 Nov 2019 15:56:28 +0800 Subject: [PATCH 3/4] init --- src/nni_manager/common/log.ts | 6 +----- src/nni_manager/main.ts | 8 ++++++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/nni_manager/common/log.ts b/src/nni_manager/common/log.ts index e2ca62f9c6..275fb76ffe 100644 --- a/src/nni_manager/common/log.ts +++ b/src/nni_manager/common/log.ts @@ -155,11 +155,7 @@ class Logger { } } -function getLogger(fileName?: string): Logger { - component.Container.bind(Logger).provider({ - get: (): Logger => new Logger(fileName) - }); - +function getLogger(): Logger { return component.get(Logger); } diff --git a/src/nni_manager/main.ts b/src/nni_manager/main.ts index fec5a8819e..fb9c8bf6de 100644 --- a/src/nni_manager/main.ts +++ b/src/nni_manager/main.ts @@ -49,7 +49,7 @@ function initStartupInfo( setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly); } -async function initContainer(platformMode: string): Promise { +async function initContainer(platformMode: string, logFileName?: string): Promise { if (platformMode === 'local') { Container.bind(TrainingService) .to(LocalTrainingService) @@ -82,6 +82,9 @@ async function initContainer(platformMode: string): Promise { Container.bind(DataStore) .to(NNIDataStore) .scope(Scope.Singleton); + Container.bind(Logger).provider({ + get: (): Logger => new Logger(logFileName) + }); const ds: DataStore = component.get(DataStore); await ds.init(); @@ -145,13 +148,14 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly); mkDirP(getLogDir()) .then(async () => { - const log: Logger = getLogger(); try { await initContainer(mode); + const log: Logger = getLogger(); const restServer: NNIRestServer = component.get(NNIRestServer); await restServer.start(); log.info(`Rest server listening on: ${restServer.endPoint}`); } catch (err) { + const log: Logger = getLogger(); log.error(`${err.stack}`); throw err; } From bb13de973555fa0502d6435c054fa30205943137 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Wed, 20 Nov 2019 15:58:44 +0800 Subject: [PATCH 4/4] sort code --- src/nni_manager/main.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nni_manager/main.ts b/src/nni_manager/main.ts index fb9c8bf6de..758694be32 100644 --- a/src/nni_manager/main.ts +++ b/src/nni_manager/main.ts @@ -150,9 +150,9 @@ mkDirP(getLogDir()) .then(async () => { try { await initContainer(mode); - const log: Logger = getLogger(); const restServer: NNIRestServer = component.get(NNIRestServer); await restServer.start(); + const log: Logger = getLogger(); log.info(`Rest server listening on: ${restServer.endPoint}`); } catch (err) { const log: Logger = getLogger();