diff --git a/nni/experiment/experiment.py b/nni/experiment/experiment.py index b83bb62b5e..4b76dc0e51 100644 --- a/nni/experiment/experiment.py +++ b/nni/experiment/experiment.py @@ -75,6 +75,7 @@ def __init__(self, config=None, training_service=None): self.id: Optional[str] = None self.port: Optional[int] = None self._proc: Optional[Popen] = None + self.mode = 'new' args = [config, training_service] # deal with overloading if isinstance(args[0], (str, list)): @@ -101,7 +102,10 @@ def start(self, port: int = 8080, debug: bool = False) -> None: """ atexit.register(self.stop) - self.id = management.generate_experiment_id() + if self.mode == 'new': + self.id = management.generate_experiment_id() + else: + self.config = launcher.get_stopped_experiment_config(self.id, self.mode) if self.config.experiment_working_directory is not None: log_dir = Path(self.config.experiment_working_directory, self.id, 'log') @@ -109,7 +113,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None: log_dir = Path.home() / f'nni-experiments/{self.id}/log' nni.runtime.log.start_experiment_log(self.id, log_dir, debug) - self._proc = launcher.start_experiment(self.id, self.config, port, debug) + self._proc = launcher.start_experiment(self.id, self.config, port, debug, mode=self.mode) assert self._proc is not None self.port = port # port will be None if start up failed @@ -189,6 +193,42 @@ def connect(cls, port: int): _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status) return experiment + @classmethod + def resume(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False): + """ + Resume a stopped experiment. + + Parameters + ---------- + experiment_id + The stopped experiment id. + """ + experiment = Experiment() + experiment.mode = 'resume' + if wait_completion: + experiment.run(port, debug) + else: + experiment.start(port, debug) + return experiment + + @classmethod + def view(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False): + """ + View a stopped experiment. + + Parameters + ---------- + experiment_id + The stopped experiment id. + """ + experiment = Experiment() + experiment.mode = 'view' + if wait_completion: + experiment.run(port, debug) + else: + experiment.start(port, debug) + return experiment + def get_status(self) -> str: """ Return experiment status as a str. diff --git a/nni/experiment/launcher.py b/nni/experiment/launcher.py index 88226cf178..be8b594705 100644 --- a/nni/experiment/launcher.py +++ b/nni/experiment/launcher.py @@ -18,31 +18,35 @@ from .config import ExperimentConfig from .pipe import Pipe from . import rest -from ..tools.nnictl.config_utils import Experiments +from ..tools.nnictl.config_utils import Experiments, Config +from ..tools.nnictl.nnictl_utils import update_experiment _logger = logging.getLogger('nni.experiment') -def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen: +def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool, mode: str = 'new') -> Popen: proc = None config.validate(initialized_tuner=False) _ensure_port_idle(port) - if isinstance(config.training_service, list): # hybrid training service - _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port') - elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']: - _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port') + + if mode != 'view': + if isinstance(config.training_service, list): # hybrid training service + _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port') + elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']: + _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port') try: _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL) - start_time, proc = _start_rest_server(config, port, debug, exp_id) + start_time, proc = _start_rest_server(config, port, debug, exp_id, mode=mode) _logger.info('Statring web server...') _check_rest_server(port) platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform _save_experiment_information(exp_id, port, start_time, platform, config.experiment_name, proc.pid, config.experiment_working_directory) - _logger.info('Setting up...') - rest.post(port, '/experiment', config.json()) + if mode != 'view': + _logger.info('Setting up...') + rest.post(port, '/experiment', config.json()) return proc except Exception as e: @@ -98,7 +102,8 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None: raise RuntimeError(f'Port {port} is not idle {message}') -def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None) -> Tuple[int, Popen]: +def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None, + mode: str = 'new') -> Tuple[int, Popen]: if isinstance(config.training_service, list): ts = 'hybrid' else: @@ -110,12 +115,16 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim 'port': port, 'mode': ts, 'experiment_id': experiment_id, - 'start_mode': 'new', + 'start_mode': mode, 'log_level': 'debug' if debug else 'info', } if pipe_path is not None: args['dispatcher_pipe'] = pipe_path + if mode == 'view': + args['start_mode'] = 'resume' + args['readonly'] = 'true' + node_dir = Path(nni_node.__path__[0]) node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) main_js = str(node_dir / 'main.js') @@ -150,3 +159,19 @@ def _check_rest_server(port: int, retry: int = 3) -> None: def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None: experiments_config = Experiments() experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir) + + +def get_stopped_experiment_config(exp_id: str, mode: str) -> None: + update_experiment() + experiments_config = Experiments() + experiments_dict = experiments_config.get_all_experiments() + experiment_metadata = experiments_dict.get(exp_id) + if experiment_metadata is None: + logging.error('Id %s not exist!', exp_id) + return + if experiment_metadata['status'] != 'STOPPED': + logging.error('Only stopped experiments can be %sed!', mode) + return + experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config() + config = ExperimentConfig(**experiment_config) + return config