Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull code #39

Merged
merged 2 commits into from
Oct 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ build
*.egg-info

.vscode

# In case you place source code in ~/nni/
/experiments
8 changes: 3 additions & 5 deletions src/sdk/pynni/nni/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
# ==================================================================================================


# pylint: disable=wildcard-import

from .trial import *
from .smartparam import *
from .nas_utils import training_update

class NoMoreTrialError(Exception):
def __init__(self,ErrorInfo):
def __init__(self, ErrorInfo):
super().__init__(self)
self.errorinfo=ErrorInfo
self.errorinfo = ErrorInfo

def __str__(self):
return self.errorinfo
return self.errorinfo
112 changes: 63 additions & 49 deletions src/sdk/pynni/nni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
import json
import importlib

from .common import enable_multi_thread, enable_multi_phase
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher
from .msg_dispatcher import MsgDispatcher

logger = logging.getLogger('nni.main')
logger.debug('START')

Expand All @@ -44,7 +45,7 @@ def augment_classargs(input_class_args, classname):
input_class_args[key] = value
return input_class_args

def create_builtin_class_instance(classname, jsonstr_args, is_advisor = False):
def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
if is_advisor:
if classname not in AdvisorModuleName or \
importlib.util.find_spec(AdvisorModuleName[classname]) is None:
Expand Down Expand Up @@ -130,55 +131,15 @@ def main():

if args.advisor_class_name:
# advisor is enabled and starts to run
if args.advisor_class_name in AdvisorModuleName:
dispatcher = create_builtin_class_instance(
args.advisor_class_name,
args.advisor_args, True)
else:
dispatcher = create_customized_class_instance(
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
if dispatcher is None:
raise AssertionError('Failed to create Advisor instance')
try:
dispatcher.run()
except Exception as exception:
logger.exception(exception)
raise
_run_advisor(args)

else:
# tuner (and assessor) is enabled and starts to run
tuner = None
assessor = None
if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
else:
tuner = create_customized_class_instance(
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)

if tuner is None:
raise AssertionError('Failed to create Tuner instance')

tuner = _create_tuner(args)
if args.assessor_class_name:
if args.assessor_class_name in ModuleName:
assessor = create_builtin_class_instance(
args.assessor_class_name,
args.assessor_args)
else:
assessor = create_customized_class_instance(
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
if assessor is None:
raise AssertionError('Failed to create Assessor instance')

assessor = _create_assessor(args)
else:
assessor = None
dispatcher = MsgDispatcher(tuner, assessor)

try:
Expand All @@ -193,6 +154,59 @@ def main():
assessor._on_error()
raise


def _run_advisor(args):
if args.advisor_class_name in AdvisorModuleName:
dispatcher = create_builtin_class_instance(
args.advisor_class_name,
args.advisor_args, True)
else:
dispatcher = create_customized_class_instance(
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
if dispatcher is None:
raise AssertionError('Failed to create Advisor instance')
try:
dispatcher.run()
except Exception as exception:
logger.exception(exception)
raise


def _create_tuner(args):
if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
else:
tuner = create_customized_class_instance(
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)
if tuner is None:
raise AssertionError('Failed to create Tuner instance')
return tuner


def _create_assessor(args):
if args.assessor_class_name in ModuleName:
assessor = create_builtin_class_instance(
args.assessor_class_name,
args.assessor_args)
else:
assessor = create_customized_class_instance(
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
return assessor


if __name__ == '__main__':
try:
main()
Expand Down
6 changes: 2 additions & 4 deletions src/sdk/pynni/nni/assessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class AssessResult(Enum):
Bad = False

class Assessor(Recoverable):
# pylint: disable=no-self-use,unused-argument

def assess_trial(self, trial_job_id, trial_history):
"""Determines whether a trial should be killed. Must override.
Expand All @@ -46,21 +45,20 @@ def trial_end(self, trial_job_id, success):
trial_job_id: identifier of the trial (str).
success: True if the trial successfully completed; False if failed or terminated.
"""
pass

def load_checkpoint(self):
"""Load the checkpoint of assessr.
path: checkpoint directory for assessor
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

def save_checkpoint(self):
"""Save the checkpoint of assessor.
path: checkpoint directory for assessor
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

def _on_exit(self):
pass
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/batch_tuner/batch_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def import_data(self, data):
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
if len(self.values) == 0:
if not self.values:
logger.info("Search space has not been initialized, skip this data import")
return

Expand Down
22 changes: 12 additions & 10 deletions src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_parameter_id():
int
parameter id
"""
global _next_parameter_id # pylint: disable=global-statement
global _next_parameter_id
_next_parameter_id += 1
return _next_parameter_id - 1

Expand Down Expand Up @@ -80,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
return params_id


class Bracket(object):
class Bracket:
"""
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
Expand All @@ -98,7 +98,7 @@ class Bracket(object):
max_budget : float
The largest budget to consider. Needs to be larger than min_budget!
The budgets will be geometrically distributed
:math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
:math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
"""
Expand Down Expand Up @@ -169,7 +169,7 @@ def inform_trial_end(self, i):
If we have generated new trials after this trial end, we will return a new trial parameters.
Otherwise, we will return None.
"""
global _KEY # pylint: disable=global-statement
global _KEY
self.num_finished_configs[i] += 1
logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d',
self.s, self.i, i, self.num_finished_configs[i], self.num_configs_to_run[i])
Expand Down Expand Up @@ -377,8 +377,10 @@ def generate_new_bracket(self):
if self.curr_s < 0:
logger.info("s < 0, Finish this round of Hyperband in BOHB. Generate new round")
self.curr_s = self.s_max
self.brackets[self.curr_s] = Bracket(s=self.curr_s, s_max=self.s_max, eta=self.eta,
max_budget=self.max_budget, optimize_mode=self.optimize_mode)
self.brackets[self.curr_s] = Bracket(
s=self.curr_s, s_max=self.s_max, eta=self.eta,
max_budget=self.max_budget, optimize_mode=self.optimize_mode
)
next_n, next_r = self.brackets[self.curr_s].get_n_r()
logger.debug(
'new SuccessiveHalving iteration, next_n=%d, next_r=%d', next_n, next_r)
Expand Down Expand Up @@ -599,7 +601,7 @@ def handle_report_metric_data(self, data):
logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
s = int(s)

# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if data['trial_job_id'] in self.job_id_para_id_map:
assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']
Expand Down Expand Up @@ -643,14 +645,14 @@ def handle_import_data(self, data):
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num, len(data)))
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
budget_exist_flag = False
barely_params = dict()
Expand All @@ -662,7 +664,7 @@ def handle_import_data(self, data):
barely_params[keys] = _params[keys]
if not budget_exist_flag:
_budget = self.max_budget
logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)" %self.max_budget)
logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)", self.max_budget)
if self.optimize_mode is OptimizeMode.Maximize:
reward = -_value
else:
Expand Down
Loading