-
Notifications
You must be signed in to change notification settings - Fork 223
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #509 from TheSunWillRise/phpo
- Loading branch information
Showing
21 changed files
with
2,942 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from federatedscope.autotune.fts.server import FTSServer | ||
from federatedscope.autotune.fts.client import FTSClient | ||
|
||
__all__ = ['FTSServer', 'FTSClient'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
import os | ||
import logging | ||
import json | ||
import copy | ||
import pickle | ||
import numpy as np | ||
|
||
from federatedscope.core.message import Message | ||
from federatedscope.core.workers import Client | ||
|
||
from federatedscope.autotune.fts.utils import * | ||
from federatedscope.autotune.utils import parse_search_space | ||
from federatedscope.core.auxiliaries.trainer_builder import get_trainer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FTSClient(Client): | ||
def __init__(self, | ||
ID=-1, | ||
server_id=None, | ||
state=-1, | ||
config=None, | ||
data=None, | ||
model=None, | ||
device='cpu', | ||
strategy=None, | ||
is_unseen_client=False, | ||
*args, | ||
**kwargs): | ||
super(FTSClient, | ||
self).__init__(ID, server_id, state, config, data, model, device, | ||
strategy, is_unseen_client, *args, **kwargs) | ||
self.data = data | ||
self.model = model | ||
self.device = device | ||
self._diff = config.hpo.fts.diff | ||
self._init_model = copy.deepcopy(model) | ||
|
||
# local file paths | ||
self.local_bo_path = os.path.join(self._cfg.hpo.working_folder, | ||
"local_bo_" + str(self.ID) + ".pkl") | ||
self.local_init_path = os.path.join( | ||
self._cfg.hpo.working_folder, | ||
"local_init_" + str(self.ID) + ".pkl") | ||
self.local_info_path = os.path.join( | ||
self._cfg.hpo.working_folder, "local_info_" + str(self.ID) + | ||
"_M_" + str(self._cfg.hpo.fts.M) + ".pkl") | ||
|
||
# prepare search space and bounds | ||
self._ss = parse_search_space(self._cfg.hpo.fts.ss) | ||
self.dim = len(self._ss) | ||
self.bounds = np.asarray([(0., 1.) for _ in self._ss]) | ||
self.pbounds = {} | ||
for k, v in self._ss.items(): | ||
if not (hasattr(v, 'lower') and hasattr(v, 'upper')): | ||
raise ValueError("Unsupported hyper type {}".format(type(v))) | ||
else: | ||
if v.log: | ||
l, u = np.log10(v.lower), np.log10(v.upper) | ||
else: | ||
l, u = v.lower, v.upper | ||
self.pbounds[k] = (l, u) | ||
|
||
def _apply_hyperparams(self, hyperparams): | ||
"""Apply the given hyperparameters | ||
Arguments: | ||
hyperparams (dict): keys are hyperparameter names \ | ||
and values are specific choices. | ||
""" | ||
|
||
cmd_args = [] | ||
for k, v in hyperparams.items(): | ||
cmd_args.append(k) | ||
cmd_args.append(v) | ||
|
||
self._cfg.defrost() | ||
self._cfg.merge_from_list(cmd_args) | ||
self._cfg.freeze(inform=False) | ||
|
||
self.trainer.ctx.setup_vars() | ||
|
||
def _get_new_trainer(self): | ||
self.model = copy.deepcopy(self._init_model) | ||
self.trainer = get_trainer(model=self.model, | ||
data=self.data, | ||
device=self.device, | ||
config=self._cfg, | ||
is_attacker=self.is_attacker, | ||
monitor=self._monitor) | ||
|
||
def _obj_func(self, x, return_eval=False): | ||
""" | ||
Run local evaluation, return some metric to maximize (e.g. val_acc) | ||
""" | ||
self._get_new_trainer() | ||
|
||
baseline = 5.0 | ||
hyperparams = x2conf(x, self.pbounds, self._ss) | ||
self._apply_hyperparams(hyperparams) | ||
|
||
results_before = self.trainer.evaluate('val') | ||
for _ in range(self._cfg.hpo.fts.local_bo_epochs): | ||
sample_size, model_para_all, results = self.trainer.train() | ||
results_after = self.trainer.evaluate('val') | ||
|
||
if self._diff: | ||
res = results_before['val_avg_loss'] \ | ||
- results_after['val_avg_loss'] | ||
else: | ||
res = baseline - results_after['val_avg_loss'] | ||
if return_eval: | ||
return res, results_after | ||
else: | ||
return res | ||
|
||
def _generate_agent_info(self, rand_feats): | ||
logger.info( | ||
('-' * 20, ' generate info on clinet %d ' % self.ID, '_' * 20)) | ||
v_kernel = self._cfg.hpo.fts.v_kernel | ||
obs_noise = self._cfg.hpo.fts.obs_noise | ||
M = self._cfg.hpo.fts.M | ||
M_target = self._cfg.hpo.fts.M_target | ||
|
||
# run standard BO locally | ||
max_iter = self._cfg.hpo.fts.local_bo_max_iter | ||
gp_opt_schedule = self._cfg.hpo.fts.gp_opt_schedule | ||
pt = np.ones(max_iter + 5) | ||
LocalBO(cid=self.ID, | ||
f=self._obj_func, | ||
bounds=self.bounds, | ||
keys=list(self.pbounds.keys()), | ||
gp_opt_schedule=gp_opt_schedule, | ||
use_init=None, | ||
log_file=self.local_bo_path, | ||
save_init=True, | ||
save_init_file=self.local_init_path, | ||
pt=pt, | ||
P_N=None, | ||
ls=self._cfg.hpo.fts.ls, | ||
var=self._cfg.hpo.fts.var, | ||
g_var=self._cfg.hpo.fts.g_var, | ||
N=self._cfg.federate.client_num - 1, | ||
M_target=M_target).maximize(n_iter=max_iter, init_points=3) | ||
|
||
# generate local RFF information | ||
res = pickle.load(open(self.local_bo_path, "rb")) | ||
ys = np.array(res["all"]["values"]).reshape(-1, 1) | ||
params = np.array(res["all"]["params"]) | ||
xs = np.array(params) | ||
xs, ys = xs[:max_iter], ys[:max_iter] | ||
Phi = np.zeros((xs.shape[0], M)) | ||
|
||
s, b = rand_feats["s"], rand_feats["b"] | ||
for i, x in enumerate(xs): | ||
x = np.squeeze(x).reshape(1, -1) | ||
features = np.sqrt(2 / M) * np.cos(np.squeeze(np.dot(x, s.T)) + b) | ||
features = features / np.sqrt(np.inner(features, features)) | ||
features = np.sqrt(v_kernel) * features | ||
Phi[i, :] = features | ||
|
||
Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M) | ||
Sigma_t_inv = np.linalg.inv(Sigma_t) | ||
nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T), ys) | ||
w_samples = np.random.multivariate_normal(np.squeeze(nu_t), | ||
obs_noise * Sigma_t_inv, 1) | ||
pickle.dump(w_samples, open(self.local_info_path, "wb")) | ||
|
||
def callback_funcs_for_model_para(self, message: Message): | ||
round, sender, content = message.state, message.sender, message.content | ||
require_agent_infos = content['require_agent_infos'] | ||
|
||
# generate local info and init then send them to server | ||
if require_agent_infos: | ||
rand_feat = content['random_feats'] | ||
self._generate_agent_info(rand_feat) | ||
agent_info = pickle.load(open(self.local_info_path, "rb")) | ||
agent_init = pickle.load(open(self.local_init_path, "rb")) | ||
content = { | ||
'is_required_agent_info': True, | ||
'agent_info': agent_info, | ||
'agent_init': agent_init, | ||
} | ||
|
||
# local run on given hyper-param and return performance | ||
else: | ||
x_max = content['x_max'] | ||
curr_y, eval_res = self._obj_func(x_max, return_eval=True) | ||
content = { | ||
'is_required_agent_info': False, | ||
'curr_y': curr_y, | ||
} | ||
hyper_param = x2conf(x_max, self.pbounds, self._ss) | ||
logger.info('{Client: %d, ' % self.ID + | ||
'GP_opt_iter: %d, ' % round + 'Params: ' + | ||
str(hyper_param) + ', Perform: ' + str(curr_y) + '}') | ||
logger.info( | ||
self._monitor.format_eval_res(eval_res, | ||
rnd=self.state, | ||
role='Client #{}'.format( | ||
self.ID), | ||
return_raw=True)) | ||
|
||
self.state = round | ||
self.comm_manager.send( | ||
Message(msg_type='model_para', | ||
sender=self.ID, | ||
receiver=[sender], | ||
state=self.state, | ||
content=content)) | ||
|
||
def callback_funcs_for_evaluate(self, message: Message): | ||
round, sender, content = \ | ||
message.state, message.sender, message.content | ||
require_agent_infos = content['require_agent_infos'] | ||
assert not require_agent_infos, \ | ||
"Can not evaluate when there is no agents' information" | ||
|
||
self.state = message.state | ||
self._obj_func(content['x_max']) | ||
|
||
metrics = {} | ||
for split in self._cfg.eval.split: | ||
eval_metrics = self.trainer.evaluate(target_data_split_name=split) | ||
for key in eval_metrics: | ||
if self._cfg.federate.mode == 'distributed': | ||
logger.info('Client #{:d}: (Evaluation ({:s} set) at ' | ||
'Round #{:d}) {:s} is {:.6f}'.format( | ||
self.ID, split, self.state, key, | ||
eval_metrics[key])) | ||
metrics.update(**eval_metrics) | ||
|
||
self.comm_manager.send( | ||
Message(msg_type='metrics', | ||
sender=self.ID, | ||
receiver=[sender], | ||
state=self.state, | ||
content=metrics)) |
Oops, something went wrong.