Commit

Merge pull request #39 from microsoft/master
pull code
chicm-ms authored Oct 30, 2019
2 parents 2175cef + 7c4b8c0 commit 2ccbfbb
Showing 53 changed files with 478 additions and 478 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -69,3 +69,6 @@ build
*.egg-info

.vscode

# In case you place source code in ~/nni/
/experiments
8 changes: 3 additions & 5 deletions src/sdk/pynni/nni/__init__.py
@@ -19,16 +19,14 @@
# ==================================================================================================


# pylint: disable=wildcard-import

from .trial import *
from .smartparam import *
from .nas_utils import training_update

class NoMoreTrialError(Exception):
def __init__(self,ErrorInfo):
def __init__(self, ErrorInfo):
super().__init__(self)
self.errorinfo=ErrorInfo
self.errorinfo = ErrorInfo

def __str__(self):
return self.errorinfo
return self.errorinfo
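
For context, a minimal usage sketch of the NoMoreTrialError class touched by this hunk; the helper function below is hypothetical, and the sketch assumes the nni package is importable, so only the exception itself comes from the code above.

from nni import NoMoreTrialError

def pop_next_parameters(pending):
    # Hypothetical helper: signal exhaustion with the exception defined above.
    if not pending:
        raise NoMoreTrialError('no more trials to generate')
    return pending.pop()

try:
    pop_next_parameters([])
except NoMoreTrialError as err:
    print(err)  # __str__ returns the stored errorinfo: 'no more trials to generate'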
112 changes: 63 additions & 49 deletions src/sdk/pynni/nni/__main__.py
@@ -27,9 +27,10 @@
import json
import importlib

from .common import enable_multi_thread, enable_multi_phase
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher
from .msg_dispatcher import MsgDispatcher

logger = logging.getLogger('nni.main')
logger.debug('START')

@@ -44,7 +45,7 @@ def augment_classargs(input_class_args, classname):
input_class_args[key] = value
return input_class_args

def create_builtin_class_instance(classname, jsonstr_args, is_advisor = False):
def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
if is_advisor:
if classname not in AdvisorModuleName or \
importlib.util.find_spec(AdvisorModuleName[classname]) is None:
@@ -130,55 +131,15 @@ def main():

if args.advisor_class_name:
# advisor is enabled and starts to run
if args.advisor_class_name in AdvisorModuleName:
dispatcher = create_builtin_class_instance(
args.advisor_class_name,
args.advisor_args, True)
else:
dispatcher = create_customized_class_instance(
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
if dispatcher is None:
raise AssertionError('Failed to create Advisor instance')
try:
dispatcher.run()
except Exception as exception:
logger.exception(exception)
raise
_run_advisor(args)

else:
# tuner (and assessor) is enabled and starts to run
tuner = None
assessor = None
if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
else:
tuner = create_customized_class_instance(
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)

if tuner is None:
raise AssertionError('Failed to create Tuner instance')

tuner = _create_tuner(args)
if args.assessor_class_name:
if args.assessor_class_name in ModuleName:
assessor = create_builtin_class_instance(
args.assessor_class_name,
args.assessor_args)
else:
assessor = create_customized_class_instance(
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
if assessor is None:
raise AssertionError('Failed to create Assessor instance')

assessor = _create_assessor(args)
else:
assessor = None
dispatcher = MsgDispatcher(tuner, assessor)

try:
@@ -193,6 +154,59 @@
assessor._on_error()
raise


def _run_advisor(args):
if args.advisor_class_name in AdvisorModuleName:
dispatcher = create_builtin_class_instance(
args.advisor_class_name,
args.advisor_args, True)
else:
dispatcher = create_customized_class_instance(
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
if dispatcher is None:
raise AssertionError('Failed to create Advisor instance')
try:
dispatcher.run()
except Exception as exception:
logger.exception(exception)
raise


def _create_tuner(args):
if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
else:
tuner = create_customized_class_instance(
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)
if tuner is None:
raise AssertionError('Failed to create Tuner instance')
return tuner


def _create_assessor(args):
if args.assessor_class_name in ModuleName:
assessor = create_builtin_class_instance(
args.assessor_class_name,
args.assessor_args)
else:
assessor = create_customized_class_instance(
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
return assessor


if __name__ == '__main__':
try:
main()
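
Pulling the refactor together, a condensed sketch of the dispatch flow after this change, based only on the hunks above; the function name dispatch is illustrative, and the final dispatcher.run() call sits in a collapsed part of the diff and is assumed here.

def dispatch(args):
    if args.advisor_class_name:
        # Advisor mode: the advisor acts as its own dispatcher.
        _run_advisor(args)
        return
    # Tuner (plus optional assessor) mode: both are handed to a MsgDispatcher.
    tuner = _create_tuner(args)
    assessor = _create_assessor(args) if args.assessor_class_name else None
    MsgDispatcher(tuner, assessor).run()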
6 changes: 2 additions & 4 deletions src/sdk/pynni/nni/assessor.py
@@ -31,7 +31,6 @@ class AssessResult(Enum):
Bad = False

class Assessor(Recoverable):
# pylint: disable=no-self-use,unused-argument

def assess_trial(self, trial_job_id, trial_history):
"""Determines whether a trial should be killed. Must override.
@@ -46,21 +45,20 @@ def trial_end(self, trial_job_id, success):
trial_job_id: identifier of the trial (str).
success: True if the trial successfully completed; False if failed or terminated.
"""
pass

def load_checkpoint(self):
"""Load the checkpoint of assessr.
path: checkpoint directory for assessor
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

def save_checkpoint(self):
"""Save the checkpoint of assessor.
path: checkpoint directory for assessor
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

def _on_exit(self):
pass
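
The logging changes in this file move from eager % interpolation to the logging module's lazy formatting; a small self-contained sketch of the difference, with placeholder names rather than NNI code.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('example')
checkpoint_path = '/tmp/checkpoint'

# Eager: the message string is built even though INFO is filtered out here.
logger.info('checkpoint path: %s' % checkpoint_path)

# Lazy: arguments are passed separately and formatted only if the record
# is actually emitted, which is the form pylint's logging-not-lazy check asks for.
logger.info('checkpoint path: %s', checkpoint_path)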
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/batch_tuner/batch_tuner.py
@@ -100,7 +100,7 @@ def import_data(self, data):
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
if len(self.values) == 0:
if not self.values:
logger.info("Search space has not been initialized, skip this data import")
return

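The batch tuner hunk swaps an explicit length comparison for the usual emptiness test; a trivial sketch of the two equivalent forms for a list.

values = []

if len(values) == 0:  # old form: explicit length check
    print('search space has not been initialized')

if not values:  # new form: truthiness check, preferred by pylint's len-as-condition
    print('search space has not been initialized')
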
22 changes: 12 additions & 10 deletions src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
@@ -51,7 +51,7 @@ def create_parameter_id():
int
parameter id
"""
global _next_parameter_id # pylint: disable=global-statement
global _next_parameter_id
_next_parameter_id += 1
return _next_parameter_id - 1

@@ -80,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
return params_id


class Bracket(object):
class Bracket:
"""
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
@@ -98,7 +98,7 @@ class Bracket(object):
max_budget : float
The largest budget to consider. Needs to be larger than min_budget!
The budgets will be geometrically distributed
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
:math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
"""
@@ -169,7 +169,7 @@ def inform_trial_end(self, i):
If we have generated new trials after this trial end, we will return a new trial parameters.
Otherwise, we will return None.
"""
global _KEY # pylint: disable=global-statement
global _KEY
self.num_finished_configs[i] += 1
logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d',
self.s, self.i, i, self.num_finished_configs[i], self.num_configs_to_run[i])
@@ -377,8 +377,10 @@ def generate_new_bracket(self):
if self.curr_s < 0:
logger.info("s < 0, Finish this round of Hyperband in BOHB. Generate new round")
self.curr_s = self.s_max
self.brackets[self.curr_s] = Bracket(s=self.curr_s, s_max=self.s_max, eta=self.eta,
max_budget=self.max_budget, optimize_mode=self.optimize_mode)
self.brackets[self.curr_s] = Bracket(
s=self.curr_s, s_max=self.s_max, eta=self.eta,
max_budget=self.max_budget, optimize_mode=self.optimize_mode
)
next_n, next_r = self.brackets[self.curr_s].get_n_r()
logger.debug(
'new SuccessiveHalving iteration, next_n=%d, next_r=%d', next_n, next_r)
@@ -599,7 +601,7 @@ def handle_report_metric_data(self, data):
logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
s = int(s)

# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if data['trial_job_id'] in self.job_id_para_id_map:
assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
@@ -643,14 +645,14 @@ def handle_import_data(self, data):
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num, len(data)))
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
budget_exist_flag = False
barely_params = dict()
@@ -662,7 +664,7 @@
barely_params[keys] = _params[keys]
if not budget_exist_flag:
_budget = self.max_budget
logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)" %self.max_budget)
logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)", self.max_budget)
if self.optimize_mode is OptimizeMode.Maximize:
reward = -_value
else:
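One docstring hunk above doubles the backslashes so that Python no longer sees invalid escape sequences such as \s and \e, while the rendered :math: expression stays the same; a raw docstring is the other common fix, sketched here with illustrative class names.

class BracketDocRaw:
    r"""With a raw docstring no escaping is needed:
    :math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
    """

class BracketDocEscaped:
    """Without the r prefix each backslash must be doubled, as the diff does:
    :math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
    """
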
(Diffs for the remaining changed files are not shown.)