Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

make assessors support metric data in dict #2121

Merged
merged 3 commits into from
Mar 6, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import datetime
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
from .model_factory import CurveModel

logger = logging.getLogger('curvefitting_Assessor')
Expand Down Expand Up @@ -91,10 +92,11 @@ def assess_trial(self, trial_job_id, trial_history):
Exception
unrecognize exception in curvefitting_assessor
"""
self.trial_history = trial_history
scalar_trial_history = extract_scalar_history(trial_history)
self.trial_history = scalar_trial_history
if not self.set_best_performance:
return AssessResult.Good
curr_step = len(trial_history)
curr_step = len(scalar_trial_history)
if curr_step < self.start_step:
return AssessResult.Good

Expand All @@ -106,7 +108,7 @@ def assess_trial(self, trial_job_id, trial_history):
start_time = datetime.datetime.now()
# Predict the final result
curvemodel = CurveModel(self.target_pos)
predict_y = curvemodel.predict(trial_history)
predict_y = curvemodel.predict(scalar_trial_history)
logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
if predict_y is None:
logger.info('wait for more information to predict precisely')
Expand Down
17 changes: 5 additions & 12 deletions src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history

logger = logging.getLogger('medianstop_Assessor')

Expand Down Expand Up @@ -91,20 +92,12 @@ def assess_trial(self, trial_job_id, trial_history):
if curr_step < self._start_step:
return AssessResult.Good

try:
num_trial_history = [float(ele) for ele in trial_history]
except (TypeError, ValueError) as error:
logger.warning('incorrect data type or value:')
logger.exception(error)
except Exception as error:
logger.warning('unrecognized exception in medianstop_assessor:')
logger.exception(error)

self._update_data(trial_job_id, num_trial_history)
scalar_trial_history = extract_scalar_history(trial_history)
self._update_data(trial_job_id, scalar_trial_history)
if self._high_better:
best_history = max(trial_history)
best_history = max(scalar_trial_history)
else:
best_history = min(trial_history)
best_history = min(scalar_trial_history)

avg_array = []
for id_ in self._completed_avg_history:
Expand Down
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,5 @@ def _earlystop_notify_tuner(self, data):
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
data['value'] = to_json(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
31 changes: 30 additions & 1 deletion src/sdk/pynni/nni/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
"""
Extract scalar reward from trial result.

Parameters
----------
value : int, float, dict
the reported final metric data
scalar_key : str
the key name that indicates the numeric number

Raises
------
RuntimeError
Expand All @@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
return reward


def extract_scalar_history(trial_history, scalar_key='default'):
"""
Extract scalar value from a list of intermediate results.

Parameters
----------
trial_history : list
accumulated intermediate results of a trial
scalar_key : str
the key name that indicates the numeric number

Raises
------
RuntimeError
Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
return [extract_scalar_reward(ele) for ele in trial_history]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the scalar_key should be passed to extract_scalar_reward

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point!



def convert_dict2tuple(value):
"""
convert dict type to tuple to solve unhashable problem.
Expand All @@ -90,7 +117,9 @@ def convert_dict2tuple(value):


def init_dispatcher_logger():
""" Initialize dispatcher logging configuration"""
"""
Initialize dispatcher logging configuration
"""
logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
Expand Down