Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

improve PBT tuner #2357

Merged
merged 10 commits into from
May 11, 2020
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 108 additions & 41 deletions src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters)
random_state = np.random.RandomState()
hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
for key in hyper_parameters.keys():
hyper_parameter = hyper_parameters[key]
if key == 'load_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
continue
elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
continue
elif search_space[key]["_type"] == "choice":
choices = search_space[key]["_value"]
ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
lb, lv = 0, choices.index(hyper_parameter) - 1
elif search_space[key]["_type"] == "randint":
lb, ub = search_space[key]["_value"][:2]
ub -= 1
Expand Down Expand Up @@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
else:
logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
continue

if search_space[key]["_type"] == "choice":
idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
hyper_parameters[key] = choices[idx]
else:
hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
Expand Down Expand Up @@ -231,6 +230,7 @@ def update_search_space(self, search_space):
for i in range(self.population_size):
hyper_parameters = json2parameter(
self.searchspace_json, is_rand, self.random_state)
hyper_parameters = split_index(hyper_parameters)
checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
Expand Down Expand Up @@ -294,7 +294,42 @@ def generate_parameters(self, parameter_id, **kwargs):
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
logger.info('Generate parameter : %s', trial_info.hyper_parameters)
return split_index(trial_info.hyper_parameters)
return trial_info.hyper_parameters

def _proceed_next_epoch(self):
"""
"""
logger.info('Proceeding to next epoch')
self.epoch += 1
self.population = []
self.pos = -1
self.running = {}
#exploit and explore
reverse = True if self.optimize_mode == OptimizeMode.Maximize else False
self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse)
cutoff = int(np.ceil(self.fraction * len(self.finished)))
tops = self.finished[:cutoff]
bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms:
top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
for trial in self.finished:
if trial not in bottoms:
trial.clean_id()
trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
self.finished_trials = 0
for _ in range(self.population_size):
trial_info = self.finished.pop()
self.population.append(trial_info)
while self.credit > 0 and self.pos + 1 < len(self.population):
self.credit -= 1
self.pos += 1
parameter_id = self.param_ids.pop()
trial_info = self.population[self.pos]
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
self.send_trial_callback(parameter_id, trial_info.hyper_parameters)

def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Expand All @@ -312,43 +347,75 @@ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
value = extract_scalar_reward(value)
trial_info = self.running.pop(parameter_id, None)
trial_info.score = value
self.finished.append(trial_info)
self.finished_trials += 1
if self.finished_trials == self.population_size:
self._proceed_next_epoch()

def trial_end(self, parameter_id, success, **kwargs):
"""
"""
if success:
return
if self.optimize_mode == OptimizeMode.Minimize:
value = -value
value = float('inf')
else:
value = float('-inf')
trial_info = self.running.pop(parameter_id, None)
trial_info.score = value
self.finished.append(trial_info)
self.finished_trials += 1
if self.finished_trials == self.population_size:
logger.info('Proceeding to next epoch')
self.epoch += 1
self.population = []
self.pos = -1
self.running = {}
#exploit and explore
self.finished = sorted(self.finished, key=lambda x: x.score, reverse=True)
cutoff = int(np.ceil(self.fraction * len(self.finished)))
tops = self.finished[:cutoff]
bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms:
top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
for trial in self.finished:
if trial not in bottoms:
trial.clean_id()
trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
self.finished_trials = 0
for _ in range(self.population_size):
trial_info = self.finished.pop()
self.population.append(trial_info)
while self.credit > 0 and self.pos + 1 < len(self.population):
self.credit -= 1
self.pos += 1
parameter_id = self.param_ids.pop()
trial_info = self.population[self.pos]
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
self.send_trial_callback(parameter_id, split_index(trial_info.hyper_parameters))
self._proceed_next_epoch()

def import_data(self, data):
pass
"""
"""
if self.running:
logger.warning("Do not support importing data in the middle of experiment")
return
# the following is for experiment resume
_completed_num = 0
epoch_data_dict = {} # key: epoch num, value: list of configurations
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
# assign fake value for failed trials
logger.info("Useless trial data, value is %s, skip this trial data.", _value)
_value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
_value = extract_scalar_reward(_value)
epoch_num = int(os.path.basename(_params['save_checkpoint_dir']))
if epoch_num not in epoch_data_dict:
epoch_data_dict[epoch_num] = []
epoch_data_dict[epoch_num].append((_params, _value))
if not epoch_data_dict:
return
# figure out start epoch for resume
max_epoch_num = max(epoch_data_dict, key=int)
if len(epoch_data_dict[max_epoch_num]) < self.population_size:
max_epoch_num = max_epoch_num - 1
# If there is no a single complete round, no data to import, start from scratch
if max_epoch_num < 0:
return
assert len(epoch_data_dict[max_epoch_num]) == self.population_size
# check existence of trial save checkpoint dir
for params, _ in epoch_data_dict[max_epoch_num]:
if not os.path.isdir(params['save_checkpoint_dir']):
logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir'])
return
# resume data
self.epoch = max_epoch_num
self.finished_trials = self.population_size
for params, value in epoch_data_dict[max_epoch_num]:
checkpoint_dir = os.path.dirname(params['save_checkpoint_dir'])
self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value))
self._proceed_next_epoch()
logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), _completed_num)
logger.info("Start from epoch %d ...", self.epoch)