Skip to content

Commit

Permalink
Rewrite ParEGO (#76)
Browse files Browse the repository at this point in the history
* rewrite parego
---------

Co-authored-by: Jhj <[email protected]>
  • Loading branch information
dusixian and jhj0411jhj authored Jan 4, 2024
1 parent eba9c97 commit 2db2af0
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 6 deletions.
9 changes: 9 additions & 0 deletions openbox/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def build_surrogate(func_str='gp', config_space=None, rng=None, transfer_learnin
func_str = func_str.lower()
types, bounds = get_types(config_space)
seed = rng.randint(MAXINT)

if func_str.startswith('parego_'):
func_str = func_str[7:]
base_surrogate = build_surrogate(
func_str=func_str, config_space=config_space, rng=rng,
transfer_learning_history=transfer_learning_history)
from openbox.surrogate.mo.parego import ParEGOSurrogate
return ParEGOSurrogate(base_surrogate=base_surrogate, seed=seed)

if func_str == 'prf':
try:
from openbox.surrogate.base.rf_with_instances import RandomForestWithInstances
Expand Down
17 changes: 11 additions & 6 deletions openbox/core/generic_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,17 @@ def setup_bo_basics(self):
-------
An optimizer object.
"""
if self.num_objectives == 1 or self.acq_type == 'parego':
if self.num_objectives == 1:
self.surrogate_model = build_surrogate(func_str=self.surrogate_type,
config_space=self.config_space,
rng=self.rng,
transfer_learning_history=self.transfer_learning_history)
elif self.acq_type == 'parego':
func_str = 'parego_' + self.surrogate_type
self.surrogate_model = build_surrogate(func_str=func_str,
config_space=self.config_space,
rng=self.rng,
transfer_learning_history=self.transfer_learning_history)
else: # multi-objectives
self.surrogate_model = [build_surrogate(func_str=self.surrogate_type,
config_space=self.config_space,
Expand Down Expand Up @@ -379,10 +385,7 @@ def get_suggestion(self, history: History = None, return_list: bool = False):
if self.num_objectives == 1:
self.surrogate_model.train(X, Y[:, 0])
elif self.acq_type == 'parego':
weights = self.rng.random_sample(self.num_objectives)
weights = weights / np.sum(weights)
scalarized_obj = get_chebyshev_scalarization(weights, Y)
self.surrogate_model.train(X, scalarized_obj(Y))
self.surrogate_model.train(X, Y)
else: # multi-objectives
for i in range(self.num_objectives):
self.surrogate_model[i].train(X, Y[:, i])
Expand All @@ -401,9 +404,11 @@ def get_suggestion(self, history: History = None, return_list: bool = False):
else: # multi-objectives
mo_incumbent_values = history.get_mo_incumbent_values()
if self.acq_type == 'parego':
scalarized_obj = self.surrogate_model.get_scalarized_obj()
incumbent_value = scalarized_obj(np.atleast_2d(mo_incumbent_values))
self.acquisition_function.update(model=self.surrogate_model,
constraint_models=self.constraint_models,
eta=scalarized_obj(np.atleast_2d(mo_incumbent_values)),
eta=incumbent_value,
num_data=num_config_evaluated)
elif self.acq_type.startswith('ehvi'):
partitioning = NondominatedPartitioning(self.num_objectives, Y)
Expand Down
Empty file.
42 changes: 42 additions & 0 deletions openbox/surrogate/mo/parego.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# License: 3-clause BSD
# Copyright (c) 2016-2018, Ml4AAD Group (http://www.ml4aad.org/)

import numpy as np

from openbox import logger
from openbox.utils.multi_objective import get_chebyshev_scalarization


class ParEGOSurrogate(object):
def __init__(self, base_surrogate, seed):
self.base_surrogate = base_surrogate
self.rng = np.random.RandomState(seed)
self.scalarized_obj = None

def train(self, X, Y):
num_objectives = Y.shape[1]

weights = self.rng.dirichlet(alpha=np.ones(num_objectives))
logger.info(f'[ParEGO] Sampled weights: {weights}')
self.scalarized_obj = get_chebyshev_scalarization(weights, Y)
Y_scalarized = self.scalarized_obj(Y)

self.base_surrogate.train(X, Y_scalarized)

def predict(self, X):
return self.base_surrogate.predict(X)

def get_scalarized_obj(self):
return self.scalarized_obj

def predict_marginalized_over_instances(self, X):
if hasattr(self.base_surrogate, "predict_marginalized_over_instances"):
return self.base_surrogate.predict_marginalized_over_instances(X)
else:
raise NotImplementedError("predict_marginalized_over_instances is not implemented for the base surrogate.")

def sample_functions(self, X, n_funcs=1):
if hasattr(self.base_surrogate, "sample_functions"):
return self.base_surrogate.sample_functions(X, n_funcs)
else:
raise NotImplementedError("Sampling functions is not implemented for the base surrogate.")

0 comments on commit 2db2af0

Please sign in to comment.