Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

improve PBT tuner #2357

Merged
merged 10 commits into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 132 additions & 41 deletions src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters)
random_state = np.random.RandomState()
hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
chicm-ms marked this conversation as resolved.
Show resolved Hide resolved
hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
for key in hyper_parameters.keys():
hyper_parameter = hyper_parameters[key]
if key == 'load_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
continue
elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
continue
elif search_space[key]["_type"] == "choice":
choices = search_space[key]["_value"]
ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
lb, lv = 0, choices.index(hyper_parameter) - 1
elif search_space[key]["_type"] == "randint":
lb, ub = search_space[key]["_value"][:2]
ub -= 1
Expand Down Expand Up @@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
else:
logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
continue

if search_space[key]["_type"] == "choice":
idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
hyper_parameters[key] = choices[idx]
else:
hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
Expand Down Expand Up @@ -231,6 +230,7 @@ def update_search_space(self, search_space):
for i in range(self.population_size):
hyper_parameters = json2parameter(
self.searchspace_json, is_rand, self.random_state)
hyper_parameters = split_index(hyper_parameters)
checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
Expand Down Expand Up @@ -294,7 +294,42 @@ def generate_parameters(self, parameter_id, **kwargs):
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
logger.info('Generate parameter : %s', trial_info.hyper_parameters)
return split_index(trial_info.hyper_parameters)
return trial_info.hyper_parameters

def _proceed_next_epoch(self):
    """
    Advance the tuner to the next PBT epoch.

    Ranks all finished trials by score, lets the bottom ``fraction`` of
    trials exploit-and-explore hyper-parameters from the top ``fraction``,
    re-points every surviving trial's checkpoint directories, rebuilds the
    population for the new epoch, and serves any parameter requests that
    were queued (``self.credit``) while the previous epoch was finishing.
    """
    logger.info('Proceeding to next epoch')
    self.epoch += 1
    self.population = []
    self.pos = -1
    self.running = {}
    # exploit and explore: sort best-first (descending when maximizing)
    reverse = self.optimize_mode == OptimizeMode.Maximize
    self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse)
    cutoff = int(np.ceil(self.fraction * len(self.finished)))
    tops = self.finished[:cutoff]
    bottoms = self.finished[self.finished_trials - cutoff:]
    for bottom in bottoms:
        top = np.random.choice(tops)
        exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
    for trial in self.finished:
        if trial not in bottoms:
            trial.clean_id()
            # survivors resume from their own previous checkpoint
            trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
            trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
    self.finished_trials = 0
    for _ in range(self.population_size):
        trial_info = self.finished.pop()
        self.population.append(trial_info)
    # serve parameter requests that arrived before the epoch rolled over
    while self.credit > 0 and self.pos + 1 < len(self.population):
        self.credit -= 1
        self.pos += 1
        parameter_id = self.param_ids.pop()
        trial_info = self.population[self.pos]
        trial_info.parameter_id = parameter_id
        self.running[parameter_id] = trial_info
        self.send_trial_callback(parameter_id, trial_info.hyper_parameters)

def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
    """
    Record a finished trial's result; once the whole population has
    reported, advance to the next PBT epoch.

    Parameters
    ----------
    parameter_id : int
        Unique identifier for hyper-parameters used by this trial.
    parameters : dict
        Hyper-parameters the trial ran with.
    value : dict or float
        Final metric reported by the trial.
    **kwargs
        Unstable parameters which should be ignored by normal users.
    """
    logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
    value = extract_scalar_reward(value)
    trial_info = self.running.pop(parameter_id, None)
    trial_info.score = value
    self.finished.append(trial_info)
    self.finished_trials += 1
    # epoch is complete only when every population member has reported
    if self.finished_trials == self.population_size:
        self._proceed_next_epoch()

def trial_end(self, parameter_id, success, **kwargs):
    """
    Deal with trial failure

    A failed trial is assigned the worst possible score (``inf`` when
    minimizing, ``-inf`` when maximizing) so it lands in the bottom of the
    ranking and gets exploited/explored away at the epoch boundary.

    Parameters
    ----------
    parameter_id : int
        Unique identifier for hyper-parameters used by this trial.
    success : bool
        True if the trial successfully completed; False if failed or terminated.
    **kwargs
        Unstable parameters which should be ignored by normal users.
    """
    if success:
        return
    if self.optimize_mode == OptimizeMode.Minimize:
        value = float('inf')
    else:
        value = float('-inf')
    trial_info = self.running.pop(parameter_id, None)
    trial_info.score = value
    self.finished.append(trial_info)
    self.finished_trials += 1
    # a failed trial still counts toward epoch completion
    if self.finished_trials == self.population_size:
        self._proceed_next_epoch()

def import_data(self, data):
    """
    Resume a PBT experiment from previously recorded trial data.

    Groups the records by epoch (derived from the basename of each
    record's ``save_checkpoint_dir``), picks the latest epoch that has a
    full population with checkpoint directories still on disk, seeds the
    finished-trial list from it, and proceeds to the next epoch.

    Parameters
    ----------
    data : json obj
        imported data records

    Returns
    -------
    int
        the start epoch number after data imported, only used for unittest
    """
    if self.running:
        logger.warning("Do not support importing data in the middle of experiment")
        return
    # the following is for experiment resume
    _completed_num = 0
    epoch_data_dict = {}
    for trial_info in data:
        logger.info("Process data record %s / %s", _completed_num, len(data))
        _completed_num += 1
        # simply validate data format
        _params = trial_info["parameter"]
        _value = trial_info['value']
        # assign fake value for failed trials
        if not _value:
            logger.info("Useless trial data, value is %s, skip this trial data.", _value)
            _value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
        _value = extract_scalar_reward(_value)
        if 'save_checkpoint_dir' not in _params:
            logger.warning("Invalid data record: save_checkpoint_dir is missing, abandon data import.")
            return
        # epoch number is encoded as the last path component of the checkpoint dir
        epoch_num = int(os.path.basename(_params['save_checkpoint_dir']))
        if epoch_num not in epoch_data_dict:
            epoch_data_dict[epoch_num] = []
        epoch_data_dict[epoch_num].append((_params, _value))
    if not epoch_data_dict:
        logger.warning("No valid epochs, abandon data import.")
        return
    # figure out start epoch for resume
    max_epoch_num = max(epoch_data_dict, key=int)
    if len(epoch_data_dict[max_epoch_num]) < self.population_size:
        max_epoch_num -= 1
    # If there is not a single complete round, no data to import, start from scratch
    if max_epoch_num < 0:
        logger.warning("No completed epoch, abandon data import.")
        return
    assert len(epoch_data_dict[max_epoch_num]) == self.population_size
    # check existence of trial save checkpoint dir
    for params, _ in epoch_data_dict[max_epoch_num]:
        if not os.path.isdir(params['save_checkpoint_dir']):
            logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir'])
            return
    # resume data
    self.epoch = max_epoch_num
    self.finished_trials = self.population_size
    for params, value in epoch_data_dict[max_epoch_num]:
        checkpoint_dir = os.path.dirname(params['save_checkpoint_dir'])
        self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value))
    self._proceed_next_epoch()
    logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), self.population_size)
    logger.info("Start from epoch %d ...", self.epoch)
    return self.epoch  # return for test
57 changes: 57 additions & 0 deletions src/sdk/pynni/tests/test_builtin_tuners.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,62 @@ def search_space_test_all(self, tuner_factory, supported_types=None, ignore_type
logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space)

def import_data_test_for_pbt(self):
    """
    test1: import data with complete epoch
    test2: import data with incomplete epoch
    """
    search_space = {
        "choice_str": {
            "_type": "choice",
            "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"]
        }
    }
    all_checkpoint_dir = os.path.expanduser("~/nni/checkpoint/test/")
    population_size = 4
    # ===import data at the beginning===
    tuner = PBTTuner(
        all_checkpoint_dir=all_checkpoint_dir,
        population_size=population_size
    )
    self.assertIsInstance(tuner, Tuner)
    tuner.update_search_space(search_space)
    save_dirs = [os.path.join(all_checkpoint_dir, str(i), str(0)) for i in range(population_size)]
    # create save checkpoint directory
    for save_dir in save_dirs:
        os.makedirs(save_dir, exist_ok=True)
    # for simplicity, omit "load_checkpoint_dir"
    data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
            {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7}]
    epoch = tuner.import_data(data)
    self.assertEqual(epoch, 1)
    logger.info("Imported data successfully at the beginning")
    shutil.rmtree(all_checkpoint_dir)
    # ===import another data at the beginning, test the case when there is an incomplete epoch===
    tuner = PBTTuner(
        all_checkpoint_dir=all_checkpoint_dir,
        population_size=population_size
    )
    self.assertIsInstance(tuner, Tuner)
    tuner.update_search_space(search_space)
    # epoch 1 has only population_size - 1 records, so resume must fall back to epoch 0
    for i in range(population_size - 1):
        save_dirs.append(os.path.join(all_checkpoint_dir, str(i), str(1)))
    for save_dir in save_dirs:
        os.makedirs(save_dir, exist_ok=True)
    data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
            {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[4]}, "value": 1.1},
            {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[5]}, "value": {"default": 1.2, "tmp": 2}},
            {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[6]}, "value": 11}]
    epoch = tuner.import_data(data)
    self.assertEqual(epoch, 1)
    logger.info("Imported data successfully at the beginning with incomplete epoch")
    shutil.rmtree(all_checkpoint_dir)

def import_data_test(self, tuner_factory, stype="choice_str"):
"""
import data at the beginning with number value and dict value
Expand Down Expand Up @@ -297,6 +353,7 @@ def test_pbt(self):
all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
population_size=100
))
self.import_data_test_for_pbt()

def tearDown(self):
file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]
Expand Down