From fc0ff8ced0e6c37d0973e0ecb1ef391e47132074 Mon Sep 17 00:00:00 2001 From: Dalong <39682259+eedalong@users.noreply.github.com> Date: Mon, 30 Nov 2020 17:27:56 +0800 Subject: [PATCH 1/2] fix checkpoint load error and stop updating paramters in evaluation stage (#3124) --- .../pytorch/quantization/quantizers.py | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 689b3c56b2..a24b2bcf71 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -73,9 +73,9 @@ def update_quantization_param(bits, rmin, rmax): ---------- bits : int quantization bits length - rmin : float + rmin : Tensor min value of real value - rmax : float + rmax : Tensor max value of real value Returns @@ -85,12 +85,17 @@ def update_quantization_param(bits, rmin, rmax): # extend the [min, max] interval to ensure that it contains 0. # Otherwise, we would not meet the requirement that 0 be an exactly # representable value. - rmin = min(rmin, 0) - rmax = max(rmax, 0) + if rmin.is_cuda: + rmin = torch.min(rmin, torch.Tensor([0]).cuda()) + rmax = torch.max(rmax, torch.Tensor([0]).cuda()) + qmin = torch.Tensor([0]).cuda() + qmax = torch.Tensor([(1 << bits) - 1]).cuda() + else: + rmin = torch.min(rmin, torch.Tensor([0])) + rmax = torch.max(rmax, torch.Tensor([0])) + qmin = torch.Tensor([0]) + qmax = torch.Tensor([(1 << bits) - 1]) - # the min and max quantized values, as floating-point values - qmin = 0 - qmax = (1 << bits) - 1 # First determine the scale. scale = (rmax - rmin) / (qmax - qmin) @@ -143,11 +148,11 @@ def __init__(self, model, config_list, optimizer=None): types of nn.module you want to apply quantization, eg. 'Conv2d' """ super().__init__(model, config_list, optimizer) - self.steps = 1 modules_to_compress = self.get_modules_to_compress() + self.bound_model.register_buffer("steps", torch.Tensor([1])) for layer, config in modules_to_compress: - layer.module.register_buffer("zero_point", None) - layer.module.register_buffer("scale", None) + layer.module.register_buffer("zero_point", torch.Tensor([0.0])) + layer.module.register_buffer("scale", torch.Tensor([1.0])) if "output" in config.get("quant_types", []): layer.module.register_buffer('ema_decay', torch.Tensor([0.99])) layer.module.register_buffer('tracked_min_biased', torch.zeros(1)) @@ -187,13 +192,17 @@ def _quantize(self, bits, op, real_val): quantization bits length op : torch.nn.Module target module - real_val : float + real_val : Tensor real value to be quantized Returns ------- - float + Tensor """ + if real_val.is_cuda: + op.zero_point = op.zero_point.cuda() + op.scale = op.scale.cuda() + transformed_val = op.zero_point + real_val / op.scale qmin = 0 qmax = (1 << bits) - 1 @@ -229,7 +238,8 @@ def quantize_weight(self, wrapper, **kwargs): quant_start_step = config.get('quant_start_step', 0) assert weight_bits >= 1, "quant bits length should be at least 1" - if quant_start_step > self.steps: + # we dont update weight in evaluation stage + if quant_start_step > self.bound_model.steps or not wrapper.training: return weight # if bias exists, quantize bias to uint32 @@ -258,15 +268,17 @@ def quantize_output(self, output, wrapper, **kwargs): quant_start_step = config.get('quant_start_step', 0) assert output_bits >= 1, "quant bits length should be at least 1" - if quant_start_step > self.steps: + if quant_start_step > self.bound_model.steps: return output - current_min, current_max = torch.min(output), torch.max(output) - module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, - module.ema_decay, self.steps) - module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, - module.ema_decay, self.steps) - module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max) + # we dont update output quantization parameters in evaluation stage + if wrapper.training: + current_min, current_max = torch.min(output), torch.max(output) + module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, + module.ema_decay, self.bound_model.steps) + module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, + module.ema_decay, self.bound_model.steps) + module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max) out = self._quantize(output_bits, module, output) out = self._dequantize(module, out) return out @@ -279,7 +291,7 @@ def step_with_optimizer(self): """ override `compressor` `step` method, quantization only happens after certain number of steps """ - self.steps += 1 + self.bound_model.steps +=1 class DoReFaQuantizer(Quantizer): From 95f731e481913d1228ae4faa012a463ddced5208 Mon Sep 17 00:00:00 2001 From: J-shang <33053116+J-shang@users.noreply.github.com> Date: Mon, 30 Nov 2020 17:31:22 +0800 Subject: [PATCH 2/2] experiment management backend (#3081) * step 1 nnictl generate experimentId & merge folder * step 2.1 modify .experiment structure * step 2.2 add lock for .experiment rw in nnictl * step 2.2 add filelock dependence * step 2.2 remove uniqueString from main.js * fix test bug * fix test bug * setp 3.1 add experiment manager * step 3.2 add getExperimentsInfo * fix eslint * add a simple file lock to support stale * step 3.3 add test * divide abs experiment manager from manager * experiment manager refactor * support .experiment sync update status * nnictl no longer uses rest api to update status or endtime * nnictl no longer uses rest api to update status or endtime * fix eslint * support .experiment sync update endtime * fix test * fix settimeout bug * fix test * adjust experiment endTime * separate simple file lock class * modify name * add 'id' in .experiment * update rest api format * fix eslint * fix issue in comments * fix rest api format * add indent in json in experiments manager * fix unittest * fix unittest * refector file lock * fix eslint * remove '__enter__' in filelock * filelock support never expire Co-authored-by: Ning Shang --- nni/tools/nnictl/common_utils.py | 36 ++++ nni/tools/nnictl/config_utils.py | 72 +++++--- nni/tools/nnictl/constants.py | 2 +- nni/tools/nnictl/launcher.py | 74 ++++---- nni/tools/nnictl/launcher_utils.py | 4 + nni/tools/nnictl/nnictl_utils.py | 61 +++---- nni/tools/nnictl/tensorboard_utils.py | 7 +- setup.py | 1 + test/ut/tools/nnictl/mock/experiment.py | 4 +- test/ut/tools/nnictl/test_config_utils.py | 2 +- test/ut/tools/nnictl/test_nnictl_utils.py | 2 +- ts/nni_manager/common/experimentManager.ts | 13 ++ ts/nni_manager/common/utils.ts | 34 +++- ts/nni_manager/core/nniExperimentsManager.ts | 171 ++++++++++++++++++ ts/nni_manager/core/nnimanager.ts | 16 +- .../core/test/experimentManager.test.ts | 60 ++++++ ts/nni_manager/core/test/nnimanager.test.ts | 22 ++- ts/nni_manager/main.ts | 16 +- ts/nni_manager/package.json | 2 + ts/nni_manager/rest_server/restHandler.ts | 17 +- .../test/mockedExperimentManager.ts | 44 +++++ .../rest_server/test/restserver.test.ts | 13 ++ 22 files changed, 546 insertions(+), 127 deletions(-) create mode 100644 ts/nni_manager/common/experimentManager.ts create mode 100644 ts/nni_manager/core/nniExperimentsManager.ts create mode 100644 ts/nni_manager/core/test/experimentManager.test.ts create mode 100644 ts/nni_manager/rest_server/test/mockedExperimentManager.ts diff --git a/nni/tools/nnictl/common_utils.py b/nni/tools/nnictl/common_utils.py index 2edbf667df..24d64a9382 100644 --- a/nni/tools/nnictl/common_utils.py +++ b/nni/tools/nnictl/common_utils.py @@ -5,11 +5,14 @@ import sys import json import tempfile +import time import socket import string import random import ruamel.yaml as yaml import psutil +import filelock +import glob from colorama import Fore from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO @@ -95,3 +98,36 @@ def generate_folder_name(): temp_dir = generate_folder_name() os.makedirs(temp_dir) return temp_dir + +class SimplePreemptiveLock(filelock.SoftFileLock): + '''this is a lock support check lock expiration, if you do not need check expiration, you can use SoftFileLock''' + def __init__(self, lock_file, stale=-1): + super(__class__, self).__init__(lock_file, timeout=-1) + self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid()) + self._stale = stale + + def _acquire(self): + open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC + try: + lock_file_names = glob.glob(self._lock_file + '.*') + for file_name in lock_file_names: + if os.path.exists(file_name) and (self._stale < 0 or time.time() - os.stat(file_name).st_mtime < self._stale): + return None + fd = os.open(self._lock_file_name, open_mode) + except (IOError, OSError): + pass + else: + self._lock_file_fd = fd + return None + + def _release(self): + os.close(self._lock_file_fd) + self._lock_file_fd = None + try: + os.remove(self._lock_file_name) + except OSError: + pass + return None + +def get_file_lock(path: string, stale=-1): + return SimplePreemptiveLock(path + '.lock', stale=-1) diff --git a/nni/tools/nnictl/config_utils.py b/nni/tools/nnictl/config_utils.py index 434bdebeca..decb93fb06 100644 --- a/nni/tools/nnictl/config_utils.py +++ b/nni/tools/nnictl/config_utils.py @@ -4,8 +4,10 @@ import os import json import shutil +import time from .constants import NNICTL_HOME_DIR from .command_utils import print_error +from .common_utils import get_file_lock class Config: '''a util class to load and save config''' @@ -34,7 +36,7 @@ def write_file(self): if self.config: try: with open(self.config_file, 'w') as file: - json.dump(self.config, file) + json.dump(self.config, file, indent=4) except IOError as error: print('Error:', error) return @@ -54,39 +56,53 @@ class Experiments: def __init__(self, home_dir=NNICTL_HOME_DIR): os.makedirs(home_dir, exist_ok=True) self.experiment_file = os.path.join(home_dir, '.experiment') - self.experiments = self.read_file() + self.lock = get_file_lock(self.experiment_file, stale=2) + with self.lock: + self.experiments = self.read_file() - def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'): - '''set {key:value} paris to self.experiment''' - self.experiments[expId] = {} - self.experiments[expId]['port'] = port - self.experiments[expId]['startTime'] = startTime - self.experiments[expId]['endTime'] = endTime - self.experiments[expId]['status'] = status - self.experiments[expId]['fileName'] = file_name - self.experiments[expId]['platform'] = platform - self.experiments[expId]['experimentName'] = experiment_name - self.write_file() + def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', + tag=[], pid=None, webuiUrl=[], logDir=[]): + '''set {key:value} pairs to self.experiment''' + with self.lock: + self.experiments = self.read_file() + self.experiments[expId] = {} + self.experiments[expId]['id'] = expId + self.experiments[expId]['port'] = port + self.experiments[expId]['startTime'] = startTime + self.experiments[expId]['endTime'] = endTime + self.experiments[expId]['status'] = status + self.experiments[expId]['platform'] = platform + self.experiments[expId]['experimentName'] = experiment_name + self.experiments[expId]['tag'] = tag + self.experiments[expId]['pid'] = pid + self.experiments[expId]['webuiUrl'] = webuiUrl + self.experiments[expId]['logDir'] = logDir + self.write_file() def update_experiment(self, expId, key, value): '''Update experiment''' - if expId not in self.experiments: - return False - self.experiments[expId][key] = value - self.write_file() - return True + with self.lock: + self.experiments = self.read_file() + if expId not in self.experiments: + return False + self.experiments[expId][key] = value + self.write_file() + return True def remove_experiment(self, expId): '''remove an experiment by id''' - if expId in self.experiments: - fileName = self.experiments.pop(expId).get('fileName') - if fileName: - logPath = os.path.join(NNICTL_HOME_DIR, fileName) - try: - shutil.rmtree(logPath) - except FileNotFoundError: - print_error('{0} does not exist.'.format(logPath)) - self.write_file() + with self.lock: + self.experiments = self.read_file() + if expId in self.experiments: + self.experiments.pop(expId) + fileName = expId + if fileName: + logPath = os.path.join(NNICTL_HOME_DIR, fileName) + try: + shutil.rmtree(logPath) + except FileNotFoundError: + print_error('{0} does not exist.'.format(logPath)) + self.write_file() def get_all_experiments(self): '''return all of experiments''' @@ -96,7 +112,7 @@ def write_file(self): '''save config to local file''' try: with open(self.experiment_file, 'w') as file: - json.dump(self.experiments, file) + json.dump(self.experiments, file, indent=4) except IOError as error: print('Error:', error) return '' diff --git a/nni/tools/nnictl/constants.py b/nni/tools/nnictl/constants.py index f64f93b289..19b2ccc366 100644 --- a/nni/tools/nnictl/constants.py +++ b/nni/tools/nnictl/constants.py @@ -4,7 +4,7 @@ import os from colorama import Fore -NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl') +NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 6a199ce88a..d82ee7c326 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -23,10 +23,11 @@ from .command_utils import check_output_command, kill_command from .nnictl_utils import update_experiment -def get_log_path(config_file_name): +def get_log_path(experiment_id): '''generate stdout and stderr log path''' - stdout_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stdout') - stderr_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stderr') + os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True) + stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log') + stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log') return stdout_full_path, stderr_full_path def print_log_content(config_file_name): @@ -38,7 +39,7 @@ def print_log_content(config_file_name): print_normal(' Stderr:') print(check_output_command(stderr_full_path)) -def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None): +def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None): '''Run nni manager process''' if detect_port(port): print_error('Port %s is used by another process, please reset the port!\n' \ @@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, node_command = os.path.join(entry_dir, 'node.exe') else: node_command = os.path.join(entry_dir, 'node') - cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform] + cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \ + '--experiment_id', experiment_id] if mode == 'view': cmds += ['--start_mode', 'resume'] cmds += ['--readonly', 'true'] @@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, cmds += ['--log_dir', log_dir] if log_level is not None: cmds += ['--log_level', log_level] - if mode in ['resume', 'view']: - cmds += ['--experiment_id', experiment_id] if foreground: cmds += ['--foreground', 'true'] - stdout_full_path, stderr_full_path = get_log_path(config_file_name) + stdout_full_path, stderr_full_path = get_log_path(experiment_id) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: - time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + start_time = time.time() + time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time)) #add time information in the header of log files log_header = LOG_HEADER % str(time_now) stdout_file.write(log_header) @@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE) else: process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) - return process, str(time_now) + return process, int(start_time * 1000) def set_trial_config(experiment_config, port, config_file_name): '''set trial configuration''' @@ -432,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res raise Exception(ERROR_INFO % 'Rest server stopped!') exit(1) -def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None): +def launch_experiment(args, experiment_config, mode, experiment_id): '''follow steps to start rest server and start experiment''' - nni_config = Config(config_file_name) + nni_config = Config(experiment_id) # check packages for tuner package_name, module_name = None, None if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'): @@ -445,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen module_name, _ = get_builtin_module_class_name('advisors', package_name) if package_name and module_name: try: - stdout_full_path, stderr_full_path = get_log_path(config_file_name) + stdout_full_path, stderr_full_path = get_log_path(experiment_id) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file) except CalledProcessError: print_error('some errors happen when import package %s.' %(package_name)) - print_log_content(config_file_name) + print_log_content(experiment_id) if package_name in INSTALLABLE_PACKAGE_META: print_error('If %s is not installed, it should be installed through '\ - '\'nnictl package install --name %s\''%(package_name, package_name)) + '\'nnictl package install --name %s\'' % (package_name, package_name)) exit(1) log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None @@ -465,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen log_level = 'debug' # start rest server rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ - mode, config_file_name, foreground, experiment_id, log_dir, log_level) + mode, experiment_id, foreground, log_dir, log_level) nni_config.set_config('restServerPid', rest_process.pid) # Deal with annotation if experiment_config.get('useAnnotation'): @@ -491,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen print_normal('Successfully started Restful server!') else: print_error('Restful server start failed!') - print_log_content(config_file_name) + print_log_content(experiment_id) try: kill_command(rest_process.pid) except Exception: @@ -500,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen if mode != 'view': # set platform configuration set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ - config_file_name, rest_process) + experiment_id, rest_process) # start a new experiment print_normal('Starting experiment...') + # save experiment information + nnictl_experiment_config = Experiments() + nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, + experiment_config['trainingServicePlatform'], + experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir) # set debug configuration if mode != 'view' and experiment_config.get('debug') is None: experiment_config['debug'] = args.debug - response = set_experiment(experiment_config, mode, args.port, config_file_name) + response = set_experiment(experiment_config, mode, args.port, experiment_id) if response: if experiment_id is None: experiment_id = json.loads(response.text).get('experiment_id') - nni_config.set_config('experimentId', experiment_id) else: print_error('Start experiment failed!') - print_log_content(config_file_name) + print_log_content(experiment_id) try: kill_command(rest_process.pid) except Exception: @@ -526,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen web_ui_url_list = get_local_urls(args.port) nni_config.set_config('webuiUrl', web_ui_url_list) - # save experiment information - nnictl_experiment_config = Experiments() - nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name, - experiment_config['trainingServicePlatform'], - experiment_config['experimentName']) - print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) if mode != 'view' and args.foreground: try: @@ -544,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen def create_experiment(args): '''start a new experiment''' - config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8)) - nni_config = Config(config_file_name) + experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) + nni_config = Config(experiment_id) + nni_config.set_config('experimentId', experiment_id) config_path = os.path.abspath(args.config) if not os.path.exists(config_path): print_error('Please set correct config path!') @@ -560,9 +560,9 @@ def create_experiment(args): nni_config.set_config('experimentConfig', experiment_config) nni_config.set_config('restServerPort', args.port) try: - launch_experiment(args, experiment_config, 'new', config_file_name) + launch_experiment(args, experiment_config, 'new', experiment_id) except Exception as exception: - nni_config = Config(config_file_name) + nni_config = Config(experiment_id) restServerPid = nni_config.get_config('restServerPid') if restServerPid: kill_command(restServerPid) @@ -589,17 +589,13 @@ def manage_stopped_experiment(args, mode): exit(1) experiment_id = args.id print_normal('{0} experiment {1}...'.format(mode, experiment_id)) - nni_config = Config(experiment_dict[experiment_id]['fileName']) + nni_config = Config(experiment_id) experiment_config = nni_config.get_config('experimentConfig') - experiment_id = nni_config.get_config('experimentId') - new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8)) - new_nni_config = Config(new_config_file_name) - new_nni_config.set_config('experimentConfig', experiment_config) - new_nni_config.set_config('restServerPort', args.port) + nni_config.set_config('restServerPort', args.port) try: - launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id) + launch_experiment(args, experiment_config, mode, experiment_id) except Exception as exception: - nni_config = Config(new_config_file_name) + nni_config = Config(experiment_id) restServerPid = nni_config.get_config('restServerPid') if restServerPid: kill_command(restServerPid) diff --git a/nni/tools/nnictl/launcher_utils.py b/nni/tools/nnictl/launcher_utils.py index 48dd2779a3..7dcfa8c57e 100644 --- a/nni/tools/nnictl/launcher_utils.py +++ b/nni/tools/nnictl/launcher_utils.py @@ -32,6 +32,8 @@ def parse_time(time): def parse_path(experiment_config, config_path): '''Parse path in config file''' expand_path(experiment_config, 'searchSpacePath') + if experiment_config.get('logDir'): + expand_path(experiment_config, 'logDir') if experiment_config.get('trial'): expand_path(experiment_config['trial'], 'codeDir') if experiment_config['trial'].get('authFile'): @@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path): root_path = os.path.dirname(config_path) if experiment_config.get('searchSpacePath'): parse_relative_path(root_path, experiment_config, 'searchSpacePath') + if experiment_config.get('logDir'): + parse_relative_path(root_path, experiment_config, 'logDir') if experiment_config.get('trial'): parse_relative_path(root_path, experiment_config['trial'], 'codeDir') if experiment_config['trial'].get('authFile'): diff --git a/nni/tools/nnictl/nnictl_utils.py b/nni/tools/nnictl/nnictl_utils.py index f2e487288b..9428cab240 100644 --- a/nni/tools/nnictl/nnictl_utils.py +++ b/nni/tools/nnictl/nnictl_utils.py @@ -30,7 +30,7 @@ def get_experiment_time(port): '''get the startTime and endTime of an experiment''' response = rest_get(experiment_url(port), REST_TIME_OUT) if response and check_response(response): - content = convert_time_stamp_to_date(json.loads(response.text)) + content = json.loads(response.text) return content.get('startTime'), content.get('endTime') return None, None @@ -50,20 +50,11 @@ def update_experiment(): for key in experiment_dict.keys(): if isinstance(experiment_dict[key], dict): if experiment_dict[key].get('status') != 'STOPPED': - nni_config = Config(experiment_dict[key]['fileName']) + nni_config = Config(key) rest_pid = nni_config.get_config('restServerPid') if not detect_process(rest_pid): experiment_config.update_experiment(key, 'status', 'STOPPED') continue - rest_port = nni_config.get_config('restServerPort') - startTime, endTime = get_experiment_time(rest_port) - if startTime: - experiment_config.update_experiment(key, 'startTime', startTime) - if endTime: - experiment_config.update_experiment(key, 'endTime', endTime) - status = get_experiment_status(rest_port) - if status: - experiment_config.update_experiment(key, 'status', status) def check_experiment_id(args, update=True): '''check if the id is valid @@ -184,9 +175,7 @@ def get_config_filename(args): if experiment_id is None: print_error('Please set correct experiment id.') exit(1) - experiment_config = Experiments() - experiment_dict = experiment_config.get_all_experiments() - return experiment_dict[experiment_id]['fileName'] + return experiment_id def get_experiment_port(args): '''get the port of experiment''' @@ -228,11 +217,9 @@ def stop_experiment(args): exit(1) experiment_id_list = parse_ids(args) if experiment_id_list: - experiment_config = Experiments() - experiment_dict = experiment_config.get_all_experiments() for experiment_id in experiment_id_list: print_normal('Stopping experiment %s' % experiment_id) - nni_config = Config(experiment_dict[experiment_id]['fileName']) + nni_config = Config(experiment_id) rest_pid = nni_config.get_config('restServerPid') if rest_pid: kill_command(rest_pid) @@ -245,9 +232,6 @@ def stop_experiment(args): print_error(exception) nni_config.set_config('tensorboardPidList', []) print_normal('Stop experiment success.') - experiment_config.update_experiment(experiment_id, 'status', 'STOPPED') - time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) - experiment_config.update_experiment(experiment_id, 'endTime', str(time_now)) def trial_ls(args): '''List trial''' @@ -553,7 +537,7 @@ def experiment_clean(args): else: break for experiment_id in experiment_id_list: - nni_config = Config(experiment_dict[experiment_id]['fileName']) + nni_config = Config(experiment_id) platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') experiment_id = nni_config.get_config('experimentId') if platform == 'remote': @@ -668,18 +652,15 @@ def experiment_list(args): experiment_dict[key]['status'], experiment_dict[key]['port'], experiment_dict[key].get('platform'), - experiment_dict[key]['startTime'], - experiment_dict[key]['endTime']) + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime']) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) return experiment_id_list def get_time_interval(time1, time2): '''get the interval of two times''' try: - #convert time to timestamp - time1 = time.mktime(time.strptime(time1, '%Y/%m/%d %H:%M:%S')) - time2 = time.mktime(time.strptime(time2, '%Y/%m/%d %H:%M:%S')) - seconds = (datetime.fromtimestamp(time2) - datetime.fromtimestamp(time1)).seconds + seconds = int((time2 - time1) / 1000) #convert seconds to day:hour:minute:second days = seconds / 86400 seconds %= 86400 @@ -708,8 +689,8 @@ def show_experiment_info(): return for key in experiment_id_list: print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \ - experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \ - get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) + experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \ + get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) print(TRIAL_MONITOR_HEAD) running, response = check_rest_server_quick(experiment_dict[key]['port']) if running: @@ -850,7 +831,7 @@ def save_experiment(args): print_error('Can only save stopped experiment!') exit(1) print_normal('Saving...') - nni_config = Config(experiment_dict[args.id]['fileName']) + nni_config = Config(args.id) logDir = os.path.join(NNI_HOME_DIR, args.id) if nni_config.get_config('logDir'): logDir = os.path.join(nni_config.get_config('logDir'), args.id) @@ -873,8 +854,8 @@ def save_experiment(args): except IOError: print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment')) exit(1) - nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, experiment_dict[args.id]['fileName']) - shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, experiment_dict[args.id]['fileName'])) + nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, args.id) + shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, args.id)) # Step3. Copy code dir if args.saveCodeDir: @@ -947,20 +928,20 @@ def load_experiment(args): print_error('Invalid: experiment id already exist!') shutil.rmtree(temp_root_dir) exit(1) - if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName'))): + if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_id)): print_error('Invalid: experiment metadata does not exist!') shutil.rmtree(temp_root_dir) exit(1) # Step2. Copy nnictl metadata - src_path = os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName')) - dest_path = os.path.join(NNICTL_HOME_DIR, experiment_metadata.get('fileName')) + src_path = os.path.join(nnictl_temp_dir, experiment_id) + dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id) if os.path.exists(dest_path): shutil.rmtree(dest_path) shutil.copytree(src_path, dest_path) # Step3. Copy experiment data - nni_config = Config(experiment_metadata.get('fileName')) + nni_config = Config(experiment_id) nnictl_exp_config = nni_config.get_config('experimentConfig') if args.logDir: logDir = args.logDir @@ -1027,13 +1008,15 @@ def load_experiment(args): experiment_config.add_experiment(experiment_id, experiment_metadata.get('port'), experiment_metadata.get('startTime'), - experiment_metadata.get('fileName'), experiment_metadata.get('platform'), experiment_metadata.get('experimentName'), experiment_metadata.get('endTime'), - experiment_metadata.get('status')) + experiment_metadata.get('status'), + experiment_metadata.get('tag'), + experiment_metadata.get('pid'), + experiment_metadata.get('webUrl'), + experiment_metadata.get('logDir')) print_normal('Load experiment %s succsss!' % experiment_id) # Step6. Cleanup temp data shutil.rmtree(temp_root_dir) - diff --git a/nni/tools/nnictl/tensorboard_utils.py b/nni/tools/nnictl/tensorboard_utils.py index 9bc8e14e48..0e1b75ddd0 100644 --- a/nni/tools/nnictl/tensorboard_utils.py +++ b/nni/tools/nnictl/tensorboard_utils.py @@ -11,7 +11,7 @@ from .url_utils import trial_jobs_url, get_local_urls from .constants import REST_TIME_OUT from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version -from .nnictl_utils import check_experiment_id, check_experiment_id +from .nnictl_utils import check_experiment_id from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local def parse_log_path(args, trial_content): @@ -95,8 +95,7 @@ def stop_tensorboard(args): experiment_id = check_experiment_id(args) experiment_config = Experiments() experiment_dict = experiment_config.get_all_experiments() - config_file_name = experiment_dict[experiment_id]['fileName'] - nni_config = Config(config_file_name) + nni_config = Config(experiment_id) tensorboard_pid_list = nni_config.get_config('tensorboardPidList') if tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list: @@ -136,7 +135,7 @@ def start_tensorboard(args): print_error("Experiment {} is stopped...".format(args.id)) return config_file_name = experiment_dict[experiment_id]['fileName'] - nni_config = Config(config_file_name) + nni_config = Config(args.id) if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl': adl_tensorboard_helper(args) return diff --git a/setup.py b/setup.py index 038dc6d4ae..ee93d1d041 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ 'scikit-learn>=0.23.2', 'pkginfo', 'websockets', + 'filelock', 'prettytable' ] diff --git a/test/ut/tools/nnictl/mock/experiment.py b/test/ut/tools/nnictl/mock/experiment.py index cbfdcf45ff..49173d481c 100644 --- a/test/ut/tools/nnictl/mock/experiment.py +++ b/test/ut/tools/nnictl/mock/experiment.py @@ -11,9 +11,9 @@ def create_mock_experiment(): nnictl_experiment_config = Experiments() - nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', '1970/01/1 01:01:01', 'aGew0x', + nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456, 'local', 'example_sklearn-classification') - nni_config = Config('aGew0x') + nni_config = Config('xOpEwA5w') # mock process cmds = ['sleep', '3600000'] process = Popen(cmds, stdout=PIPE, stderr=STDOUT) diff --git a/test/ut/tools/nnictl/test_config_utils.py b/test/ut/tools/nnictl/test_config_utils.py index df970b80f4..23a8cd96c7 100644 --- a/test/ut/tools/nnictl/test_config_utils.py +++ b/test/ut/tools/nnictl/test_config_utils.py @@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase): def test_update_experiment(self): experiment = Experiments(HOME_PATH) - experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'aGew0x', 'local', 'test', endTime='N/A', status='INITIALIZED') + experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED') self.assertTrue('xOpEwA5w' in experiment.get_all_experiments()) experiment.remove_experiment('xOpEwA5w') self.assertFalse('xOpEwA5w' in experiment.get_all_experiments()) diff --git a/test/ut/tools/nnictl/test_nnictl_utils.py b/test/ut/tools/nnictl/test_nnictl_utils.py index a7f16c453d..7f2d386ce3 100644 --- a/test/ut/tools/nnictl/test_nnictl_utils.py +++ b/test/ut/tools/nnictl/test_nnictl_utils.py @@ -46,7 +46,7 @@ def test_parse_ids(self): @responses.activate def test_get_config_file_name(self): args = generate_args() - self.assertEqual('aGew0x', get_config_filename(args)) + self.assertEqual('xOpEwA5w', get_config_filename(args)) @responses.activate def test_get_experiment_port(self): diff --git a/ts/nni_manager/common/experimentManager.ts b/ts/nni_manager/common/experimentManager.ts new file mode 100644 index 0000000000..0baf0a54ed --- /dev/null +++ b/ts/nni_manager/common/experimentManager.ts @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +abstract class ExperimentManager { + public abstract getExperimentsInfo(): Promise; + public abstract setExperimentPath(newPath: string): void; + public abstract setExperimentInfo(experimentId: string, key: string, value: any): void; + public abstract stop(): Promise; +} + +export {ExperimentManager}; diff --git a/ts/nni_manager/common/utils.ts b/ts/nni_manager/common/utils.ts index 99c2d4e0c3..96039986e8 100644 --- a/ts/nni_manager/common/utils.ts +++ b/ts/nni_manager/common/utils.ts @@ -11,13 +11,16 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process'; import * as fs from 'fs'; import * as os from 'os'; import * as path from 'path'; +import * as lockfile from 'lockfile'; import { Deferred } from 'ts-deferred'; import { Container } from 'typescript-ioc'; import * as util from 'util'; +import * as glob from 'glob'; import { Database, DataStore } from './datastore'; import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentParams, Manager } from './manager'; +import { ExperimentManager } from './experimentManager'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { logLevelNameMap } from './log'; @@ -43,6 +46,10 @@ function getCheckpointDir(): string { return path.join(getExperimentRootDir(), 'checkpoint'); } +function getExperimentsInfoPath(): string { + return path.join(os.homedir(), 'nni-experiments', '.experiment'); +} + function mkDirP(dirPath: string): Promise { const deferred: Deferred = new Deferred(); fs.exists(dirPath, (exists: boolean) => { @@ -184,6 +191,7 @@ function prepareUnitTest(): void { Container.snapshot(DataStore); Container.snapshot(TrainingService); Container.snapshot(Manager); + Container.snapshot(ExperimentManager); const logLevel: string = parseArg(['--log_level', '-ll']); if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) { @@ -211,6 +219,7 @@ function cleanupUnitTest(): void { Container.restore(DataStore); Container.restore(Database); Container.restore(ExperimentStartupInfo); + Container.restore(ExperimentManager); } let cachedipv4Address: string = ''; @@ -416,8 +425,29 @@ function unixPathJoin(...paths: any[]): string { return dir; } +/** + * lock a file sync + */ +function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any { + const lockName = path.join(path.dirname(filePath), path.basename(filePath) + `.lock.${process.pid}`); + if (typeof lockOpts.stale === 'number'){ + const lockPath = path.join(path.dirname(filePath), path.basename(filePath) + '.lock.*'); + const lockFileNames: string[] = glob.sync(lockPath); + const canLock: boolean = lockFileNames.map((fileName) => { + return fs.existsSync(fileName) && Date.now() - fs.statSync(fileName).mtimeMs > lockOpts.stale; + }).filter(isExpired=>isExpired === false).length === 0; + if (!canLock) { + throw new Error('File has been locked.'); + } + } + lockfile.lockSync(lockName, lockOpts); + const result = func(...args); + lockfile.unlockSync(lockName); + return result; +} + export { - countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, - getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, + countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath, + getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine }; diff --git a/ts/nni_manager/core/nniExperimentsManager.ts b/ts/nni_manager/core/nniExperimentsManager.ts new file mode 100644 index 0000000000..939f7f6e52 --- /dev/null +++ b/ts/nni_manager/core/nniExperimentsManager.ts @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import * as fs from 'fs'; +import * as os from 'os'; +import * as path from 'path'; +import * as assert from 'assert'; + +import { getLogger, Logger } from '../common/log'; +import { isAlive, withLockSync, getExperimentsInfoPath, delay } from '../common/utils'; +import { ExperimentManager } from '../common/experimentManager'; +import { Deferred } from 'ts-deferred'; + +interface CrashedInfo { + experimentId: string; + isCrashed: boolean; +} + +interface FileInfo { + buffer: Buffer; + mtime: number; +} + +class NNIExperimentsManager implements ExperimentManager { + private experimentsPath: string; + private log: Logger; + private profileUpdateTimer: {[key: string]: any}; + + constructor() { + this.experimentsPath = getExperimentsInfoPath(); + this.log = getLogger(); + this.profileUpdateTimer = {}; + } + + public async getExperimentsInfo(): Promise { + const fileInfo: FileInfo = await this.withLockIterated(this.readExperimentsInfo, 100); + const experimentsInformation = JSON.parse(fileInfo.buffer.toString()); + const expIdList: Array = Object.keys(experimentsInformation).filter((expId) => { + return experimentsInformation[expId]['status'] !== 'STOPPED'; + }); + const updateList: Array = (await Promise.all(expIdList.map((expId) => { + return this.checkCrashed(expId, experimentsInformation[expId]['pid']); + }))).filter(crashedInfo => crashedInfo.isCrashed); + if (updateList.length > 0){ + const result = await this.withLockIterated(this.updateAllStatus, 100, updateList.map(crashedInfo => crashedInfo.experimentId), fileInfo.mtime); + if (result !== undefined) { + return JSON.parse(JSON.stringify(Object.keys(result).map(key=>result[key]))); + } else { + await delay(500); + return await this.getExperimentsInfo(); + } + } else { + return JSON.parse(JSON.stringify(Object.keys(experimentsInformation).map(key=>experimentsInformation[key]))); + } + } + + public setExperimentPath(newPath: string): void { + if (newPath[0] === '~') { + newPath = path.join(os.homedir(), newPath.slice(1)); + } + if (!path.isAbsolute(newPath)) { + newPath = path.resolve(newPath); + } + this.log.info(`Set new experiment information path: ${newPath}`); + this.experimentsPath = newPath; + } + + public setExperimentInfo(experimentId: string, key: string, value: any): void { + try { + if (this.profileUpdateTimer[key] !== undefined) { + // if a new call with the same timerId occurs, destroy the unfinished old one + clearTimeout(this.profileUpdateTimer[key]); + this.profileUpdateTimer[key] = undefined; + } + this.withLockSync(() => { + const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString()); + assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`); + experimentsInformation[experimentId][key] = value; + fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4)); + }); + } catch (err) { + this.log.error(err); + this.log.debug(`Experiment Manager: Retry set key value: ${experimentId} {${key}: ${value}}`); + if (err.code === 'EEXIST' || err.message === 'File has been locked.') { + this.profileUpdateTimer[key] = setTimeout(this.setExperimentInfo.bind(this), 100, experimentId, key, value); + } + } + } + + private async withLockIterated (func: Function, retry: number, ...args: any): Promise { + if (retry < 0) { + throw new Error('Lock file out of retries.'); + } + try { + return this.withLockSync(func, ...args); + } catch(err) { + if (err.code === 'EEXIST' || err.message === 'File has been locked.') { + // retry wait is 50ms + await delay(50); + return await this.withLockIterated(func, retry - 1, ...args); + } + throw err; + } + } + + private withLockSync (func: Function, ...args: any): any { + return withLockSync(func.bind(this), this.experimentsPath, {stale: 2 * 1000}, ...args); + } + + private readExperimentsInfo(): FileInfo { + const buffer: Buffer = fs.readFileSync(this.experimentsPath); + const mtime: number = fs.statSync(this.experimentsPath).mtimeMs; + return {buffer: buffer, mtime: mtime}; + } + + private async checkCrashed(expId: string, pid: number): Promise { + const alive: boolean = await isAlive(pid); + return {experimentId: expId, isCrashed: !alive} + } + + private updateAllStatus(updateList: Array, timestamp: number): {[key: string]: any} | undefined { + if (timestamp !== fs.statSync(this.experimentsPath).mtimeMs) { + return; + } else { + const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString()); + updateList.forEach((expId: string) => { + if (experimentsInformation[expId]) { + experimentsInformation[expId]['status'] = 'STOPPED'; + } else { + this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`); + } + }); + fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4)); + return experimentsInformation; + } + } + + public async stop(): Promise { + this.log.debug('Stopping experiment manager.'); + await this.cleanUp().catch(err=>this.log.error(err.message)); + this.log.debug('Experiment manager stopped.'); + } + + private async cleanUp(): Promise { + const deferred = new Deferred(); + if (this.isUndone()) { + this.log.debug('Experiment manager: something undone'); + setTimeout(((deferred: Deferred): void => { + if (this.isUndone()) { + deferred.reject(new Error('Still has undone after 5s, forced stop.')); + } else { + deferred.resolve(); + } + }).bind(this), 5 * 1000, deferred); + } else { + this.log.debug('Experiment manager: all clean up'); + deferred.resolve(); + } + return deferred.promise; + } + + private isUndone(): boolean { + return Object.keys(this.profileUpdateTimer).filter((key: string) => { + return this.profileUpdateTimer[key] !== undefined; + }).length > 0; + } +} + +export { NNIExperimentsManager }; diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index e9cd44e20a..379da5cca2 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ b/ts/nni_manager/core/nnimanager.ts @@ -15,6 +15,7 @@ import { ExperimentParams, ExperimentProfile, Manager, ExperimentStatus, NNIManagerStatus, ProfileUpdateType, TrialJobStatistics } from '../common/manager'; +import { ExperimentManager } from '../common/experimentManager'; import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType } from '../common/trainingService'; @@ -31,6 +32,7 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface'; class NNIManager implements Manager { private trainingService: TrainingService; private dispatcher: IpcInterface | undefined; + private experimentManager: ExperimentManager; private currSubmittedTrialNum: number; // need to be recovered private trialConcurrencyChange: number; // >0: increase, <0: decrease private log: Logger; @@ -49,6 +51,7 @@ class NNIManager implements Manager { this.currSubmittedTrialNum = 0; this.trialConcurrencyChange = 0; this.trainingService = component.get(TrainingService); + this.experimentManager = component.get(ExperimentManager); assert(this.trainingService); this.dispatcherPid = 0; this.waitingTrials = []; @@ -467,7 +470,9 @@ class NNIManager implements Manager { } } await this.trainingService.cleanUp(); - this.experimentProfile.endTime = Date.now(); + if (this.experimentProfile.endTime === undefined) { + this.setEndtime(); + } await this.storeExperimentProfile(); this.setStatus('STOPPED'); } @@ -596,7 +601,7 @@ class NNIManager implements Manager { assert(allFinishedTrialJobNum <= waitSubmittedToFinish); if (allFinishedTrialJobNum >= waitSubmittedToFinish) { this.setStatus('DONE'); - this.experimentProfile.endTime = Date.now(); + this.setEndtime(); await this.storeExperimentProfile(); // write this log for travis CI this.log.info('Experiment done.'); @@ -796,6 +801,7 @@ class NNIManager implements Manager { this.log.error(err.stack); } this.status.errors.push(err.message); + this.setEndtime(); this.setStatus('ERROR'); } @@ -803,9 +809,15 @@ class NNIManager implements Manager { if (status !== this.status.status) { this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`); this.status.status = status; + this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'status', this.status.status); } } + private setEndtime(): void { + this.experimentProfile.endTime = Date.now(); + this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime); + } + private createEmptyExperimentProfile(): ExperimentProfile { return { id: getExperimentId(), diff --git a/ts/nni_manager/core/test/experimentManager.test.ts b/ts/nni_manager/core/test/experimentManager.test.ts new file mode 100644 index 0000000000..eb7b5c44d6 --- /dev/null +++ b/ts/nni_manager/core/test/experimentManager.test.ts @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import { assert, expect } from 'chai'; +import * as fs from 'fs'; +import { Container, Scope } from 'typescript-ioc'; + +import * as component from '../../common/component'; +import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; +import { ExperimentManager } from '../../common/experimentManager'; +import { NNIExperimentsManager } from '../nniExperimentsManager'; + + +describe('Unit test for experiment manager', function () { + let experimentManager: NNIExperimentsManager; + const mockedInfo = { + "test": { + "port": 8080, + "startTime": 1605246730756, + "endTime": "N/A", + "status": "INITIALIZED", + "platform": "local", + "experimentName": "testExp", + "tag": [], "pid": 11111, + "webuiUrl": [], + "logDir": null + } + } + + before(() => { + prepareUnitTest(); + fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo)); + Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton); + experimentManager = component.get(NNIExperimentsManager); + experimentManager.setExperimentPath('.experiment.test'); + }); + + after(() => { + if (fs.existsSync('.experiment.test')) { + fs.unlinkSync('.experiment.test'); + } + cleanupUnitTest(); + }); + + it('test getExperimentsInfo', () => { + return experimentManager.getExperimentsInfo().then(function (experimentsInfo: {[key: string]: any}) { + new Array(experimentsInfo) + for (let idx in experimentsInfo) { + if (experimentsInfo[idx]['id'] === 'test') { + expect(experimentsInfo[idx]['status']).to.be.oneOf(['STOPPED', 'ERROR']); + break; + } + } + }).catch((error) => { + assert.fail(error); + }) + }); +}); diff --git a/ts/nni_manager/core/test/nnimanager.test.ts b/ts/nni_manager/core/test/nnimanager.test.ts index 2e35f7b4fc..7e0add757c 100644 --- a/ts/nni_manager/core/test/nnimanager.test.ts +++ b/ts/nni_manager/core/test/nnimanager.test.ts @@ -3,6 +3,7 @@ 'use strict'; +import * as fs from 'fs'; import * as os from 'os'; import { assert, expect } from 'chai'; import { Container, Scope } from 'typescript-ioc'; @@ -10,9 +11,10 @@ import { Container, Scope } from 'typescript-ioc'; import * as component from '../../common/component'; import { Database, DataStore } from '../../common/datastore'; import { Manager, ExperimentProfile} from '../../common/manager'; +import { ExperimentManager } from '../../common/experimentManager'; import { TrainingService } from '../../common/trainingService'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; -import { NNIDataStore } from '../nniDataStore'; +import { NNIExperimentsManager } from '../nniExperimentsManager'; import { NNIManager } from '../nnimanager'; import { SqlDB } from '../sqlDatabase'; import { MockedTrainingService } from './mockedTrainingService'; @@ -25,6 +27,7 @@ async function initContainer(): Promise { Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); + Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton); await component.get(DataStore).init(); } @@ -87,9 +90,26 @@ describe('Unit test for nnimanager', function () { revision: 0 } + let mockedInfo = { + "unittest": { + "port": 8080, + "startTime": 1605246730756, + "endTime": "N/A", + "status": "INITIALIZED", + "platform": "local", + "experimentName": "testExp", + "tag": [], "pid": 11111, + "webuiUrl": [], + "logDir": null + } + } + before(async () => { await initContainer(); + fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo)); + const experimentsManager: ExperimentManager = component.get(ExperimentManager); + experimentsManager.setExperimentPath('.experiment.test'); nniManager = component.get(Manager); const expId: string = await nniManager.startExperiment(experimentParams); assert.strictEqual(expId, 'unittest'); diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index f00c757425..e9f1d5b68a 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -12,11 +12,13 @@ import { Database, DataStore } from './common/datastore'; import { setExperimentStartupInfo } from './common/experimentStartupInfo'; import { getLogger, Logger, logLevelNameMap } from './common/log'; import { Manager, ExperimentStartUpMode } from './common/manager'; +import { ExperimentManager } from './common/experimentManager'; import { TrainingService } from './common/trainingService'; -import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils'; +import { getLogDir, mkDirP, parseArg } from './common/utils'; import { NNIDataStore } from './core/nniDataStore'; import { NNIManager } from './core/nnimanager'; import { SqlDB } from './core/sqlDatabase'; +import { NNIExperimentsManager } from './core/nniExperimentsManager'; import { NNIRestServer } from './rest_server/nniRestServer'; import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService'; import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService'; @@ -27,11 +29,10 @@ import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTr import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService'; function initStartupInfo( - startExpMode: string, resumeExperimentId: string, basePort: number, platform: string, + startExpMode: string, experimentId: string, basePort: number, platform: string, logDirectory: string, experimentLogLevel: string, readonly: boolean): void { const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW); - const expId: string = createNew ? uniqueString(8) : resumeExperimentId; - setExperimentStartupInfo(createNew, expId, basePort, platform, logDirectory, experimentLogLevel, readonly); + setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly); } async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { @@ -83,6 +84,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN Container.bind(DataStore) .to(NNIDataStore) .scope(Scope.Singleton); + Container.bind(ExperimentManager) + .to(NNIExperimentsManager) + .scope(Scope.Singleton); const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log'); if (foreground) { logFileName = undefined; @@ -133,7 +137,7 @@ if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMod } const experimentId: string = parseArg(['--experiment_id', '-id']); -if ((startMode === ExperimentStartUpMode.RESUME) && experimentId.trim().length < 1) { +if (experimentId.trim().length < 1) { console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`); usage(); process.exit(1); @@ -185,6 +189,8 @@ async function cleanUp(): Promise { try { const nniManager: Manager = component.get(Manager); await nniManager.stopExperiment(); + const experimentManager: ExperimentManager = component.get(ExperimentManager); + await experimentManager.stop(); const ds: DataStore = component.get(DataStore); await ds.close(); const restServer: NNIRestServer = component.get(NNIRestServer); diff --git a/ts/nni_manager/package.json b/ts/nni_manager/package.json index aabb58e6d9..b84ff98631 100644 --- a/ts/nni_manager/package.json +++ b/ts/nni_manager/package.json @@ -18,6 +18,7 @@ "ignore": "^5.1.4", "js-base64": "^2.4.9", "kubernetes-client": "^6.5.0", + "lockfile": "^1.0.4", "python-shell": "^2.0.1", "rx": "^4.1.0", "sqlite3": "^5.0.0", @@ -39,6 +40,7 @@ "@types/glob": "^7.1.1", "@types/js-base64": "^2.3.1", "@types/js-yaml": "^3.12.5", + "@types/lockfile": "^1.0.0", "@types/mocha": "^8.0.3", "@types/node": "10.12.18", "@types/request": "^2.47.1", diff --git a/ts/nni_manager/rest_server/restHandler.ts b/ts/nni_manager/rest_server/restHandler.ts index 2b1cf89c58..d619e73d2c 100644 --- a/ts/nni_manager/rest_server/restHandler.ts +++ b/ts/nni_manager/rest_server/restHandler.ts @@ -12,20 +12,22 @@ import { NNIError, NNIErrorNames } from '../common/errors'; import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo'; import { getLogger, Logger } from '../common/log'; import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager'; +import { ExperimentManager } from '../common/experimentManager'; import { ValidationSchemas } from './restValidationSchemas'; import { NNIRestServer } from './nniRestServer'; import { getVersion } from '../common/utils'; -import { NNIManager } from "../core/nnimanager"; const expressJoi = require('express-joi-validator'); class NNIRestHandler { private restServer: NNIRestServer; - private nniManager: NNIManager; + private nniManager: Manager; + private experimentsManager: ExperimentManager; private log: Logger; constructor(rs: NNIRestServer) { this.nniManager = component.get(Manager); + this.experimentsManager = component.get(ExperimentManager); this.restServer = rs; this.log = getLogger(); } @@ -61,6 +63,7 @@ class NNIRestHandler { this.getLatestMetricData(router); this.getTrialLog(router); this.exportData(router); + this.getExperimentsInfo(router); // Express-joi-validator configuration router.use((err: any, _req: Request, res: Response, _next: any) => { @@ -306,6 +309,16 @@ class NNIRestHandler { }); } + private getExperimentsInfo(router: Router): void { + router.get('/experiments-info', (req: Request, res: Response) => { + this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => { + res.send(JSON.stringify(experimentInfo)); + }).catch((err: Error) => { + this.handleError(err, res); + }); + }); + } + private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo { if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) { return jobInfo; diff --git a/ts/nni_manager/rest_server/test/mockedExperimentManager.ts b/ts/nni_manager/rest_server/test/mockedExperimentManager.ts new file mode 100644 index 0000000000..b02c6584c5 --- /dev/null +++ b/ts/nni_manager/rest_server/test/mockedExperimentManager.ts @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import { ExperimentManager } from '../../common/experimentManager'; +import { Provider } from 'typescript-ioc'; + +export const testExperimentManagerProvider: Provider = { + get: (): ExperimentManager => { return new mockedeExperimentManager(); } +}; + +export class mockedeExperimentManager extends ExperimentManager { + public getExperimentsInfo(): Promise { + const expInfo = JSON.parse(JSON.stringify({ + "test": { + "port": 8080, + "startTime": 1605246730756, + "endTime": "N/A", + "status": "RUNNING", + "platform": "local", + "experimentName": "testExp", + "tag": [], "pid": 11111, + "webuiUrl": [], + "logDir": null + } + })); + return new Promise((resolve, reject) => { + resolve(expInfo); + }); + } + + public setExperimentPath(newPath: string): void { + return + } + + public setExperimentInfo(experimentId: string, key: string, value: any): void { + return + } + + public stop(): Promise { + return new Promise(()=>{}); + } +} diff --git a/ts/nni_manager/rest_server/test/restserver.test.ts b/ts/nni_manager/rest_server/test/restserver.test.ts index 1f1a6b9132..cc83131b9d 100644 --- a/ts/nni_manager/rest_server/test/restserver.test.ts +++ b/ts/nni_manager/rest_server/test/restserver.test.ts @@ -10,12 +10,14 @@ import { Container } from 'typescript-ioc'; import * as component from '../../common/component'; import { DataStore } from '../../common/datastore'; import { ExperimentProfile, Manager } from '../../common/manager'; +import { ExperimentManager } from '../../common/experimentManager' import { TrainingService } from '../../common/trainingService'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { MockedDataStore } from '../../core/test/mockedDatastore'; import { MockedTrainingService } from '../../core/test/mockedTrainingService'; import { NNIRestServer } from '../nniRestServer'; import { testManagerProvider } from './mockedNNIManager'; +import { testExperimentManagerProvider } from './mockedExperimentManager'; describe('Unit test for rest server', () => { @@ -26,6 +28,7 @@ describe('Unit test for rest server', () => { Container.bind(Manager).provider(testManagerProvider); Container.bind(DataStore).to(MockedDataStore); Container.bind(TrainingService).to(MockedTrainingService); + Container.bind(ExperimentManager).provider(testExperimentManagerProvider) const restServer: NNIRestServer = component.get(NNIRestServer); restServer.start().then(() => { ROOT_URL = `${restServer.endPoint}/api/v1/nni`; @@ -84,6 +87,16 @@ describe('Unit test for rest server', () => { }); }); + it('Test GET experiments-info', (done: Mocha.Done) => { + request.get(`${ROOT_URL}/experiments-info`, (err: Error, res: request.Response) => { + expect(res.statusCode).to.equal(200); + if (err) { + assert.fail(err.message); + } + done(); + }); + }); + it('Test change concurrent-trial-jobs', (done: Mocha.Done) => { request.get(`${ROOT_URL}/experiment`, (err: Error, res: request.Response, body: any) => { if (err) {