Skip to content

Commit

Permalink
fix minor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk committed Feb 10, 2023
1 parent 24844fb commit 5dd8558
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 31 deletions.
77 changes: 50 additions & 27 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,6 @@ def run(self):
def check_and_move_on(self,
check_eval_result=False,
min_received_num=None):
"""
To check the message_buffer. When enough messages are receiving, \
some events (such as perform aggregation, evaluation, and move to \
the next training round) would be triggered.
Arguments:
check_eval_result (bool): If True, check the message buffer for \
evaluation; and check the message buffer for training \
otherwise.
min_received_num: number of minimal received message, used for \
async mode
"""
if min_received_num is None:
if self._cfg.asyn.use:
min_received_num = self._cfg.asyn.min_received_num
Expand All @@ -324,6 +312,27 @@ def check_and_move_on(self,
aggregated_num = self._perform_federated_aggregation()

self.state += 1

# FedSWA cache model
if self.state == self._cfg.fedswa.start_rnd:
self.swa_models_ws = [
model.state_dict() for model in self.models
]
self.swa_rnd = 1
self.eval_swa = False
elif self.state > \
self._cfg.fedswa.start_rnd and \
(self.state - self._cfg.fedswa.start_rnd) % \
self._cfg.fedswa.freq == 0:
logger.info(f'FedSWA cache {self.swa_rnd} models.')
for model, new_model in zip(self.swa_models_ws,
self.models):
new_model = new_model.state_dict()
for key in model.keys():
model[key] = (model[key] * self.swa_rnd +
new_model[key]) / (self.swa_rnd + 1)
self.swa_rnd += 1

if self.state % self._cfg.eval.freq == 0 and self.state != \
self.total_round_num:
# Evaluate
Expand Down Expand Up @@ -850,43 +859,57 @@ def terminate(self, msg_type='finish'):
content=model_para))

def eval(self):
"""
To conduct evaluation. When ``cfg.federate.make_global_eval=True``, \
a global evaluation is conducted by the server.
"""

if self._cfg.federate.make_global_eval:
# By default, the evaluation is conducted one-by-one for all
# internal models;
# for other cases such as ensemble, override the eval function
for i in range(self.model_num):
trainer = self.trainers[i]

if self.eval_swa:
# Use swa model
fedavg_model_w = self.models[i].state_dict()
self.models[i].load_state_dict(self.swa_models_ws[i])

                # Perform evaluation in server
metrics = {}
for split in self._cfg.eval.split:
eval_metrics = trainer.evaluate(
target_data_split_name=split)
metrics.update(**eval_metrics)

formatted_eval_res = self._monitor.format_eval_res(
metrics,
rnd=self.state,
role='Server #',
role='Server SWA#' if self.eval_swa else 'Server #',
forms=self._cfg.eval.report,
return_raw=self._cfg.federate.make_global_eval)
self._monitor.update_best_result(
self.best_results,
formatted_eval_res['Results_raw'],
results_type="server_global_eval")
self.history_results = merge_dict_of_results(
self.history_results, formatted_eval_res)

if self.eval_swa:
# Restore
self.models[i].load_state_dict(fedavg_model_w)
self.best_results = formatted_eval_res['Results_raw']
else:
self._monitor.update_best_result(
self.best_results,
formatted_eval_res['Results_raw'],
results_type="server_global_eval")
self.history_results = merge_dict_of_results(
self.history_results, formatted_eval_res)
self._monitor.save_formatted_results(formatted_eval_res)
logger.info(formatted_eval_res)
self.check_and_save()
else:
if self.eval_swa:
for i in range(self.model_num):
# Use swa model
fedavg_model_w = self.models[i].state_dict()
self.models[i].load_state_dict(self.swa_models_ws[i])
            # Perform evaluation in clients
self.broadcast_model_para(msg_type='evaluate',
filter_unseen_clients=False)

if self.eval_swa:
for i in range(self.model_num):
self.models[i].load_state_dict(fedavg_model_w)

def callback_funcs_model_para(self, message: Message):
"""
The handling function for receiving model parameters, which triggers \
Expand Down
2 changes: 1 addition & 1 deletion federatedscope/core/workers/wrapper/fedswa.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def check_and_move_on(self,
model.state_dict() for model in self.models
]
self.swa_rnd = 1
self.eval_swa = False
elif self.state > \
self._cfg.fedswa.start_rnd and \
(self.state - self._cfg.fedswa.start_rnd) % \
Expand Down Expand Up @@ -212,6 +211,7 @@ def save_best_results(self):
self._monitor.save_formatted_results(formatted_best_res)

# Bind method to instance
setattr(server, 'eval_swa', False)
server.check_and_move_on = types.MethodType(check_and_move_on, server)
server.eval = types.MethodType(eval, server)
server.check_and_save = types.MethodType(check_and_save, server)
Expand Down
6 changes: 3 additions & 3 deletions scripts/example_configs/fedswa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ early_stop:
seed: 12345
federate:
mode: standalone
total_round_num: 6
total_round_num: 100
sample_client_rate: 0.1
make_global_eval: True
merge_test_data: True
merge_val_data: True
fedswa:
use: True
freq: 3
start_rnd: 2
freq: 10
start_rnd: 75
data:
root: data/
type: femnist
Expand Down

0 comments on commit 5dd8558

Please sign in to comment.