Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SWA #519

Merged
merged 6 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions federatedscope/core/configs/cfg_fl_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def extend_fl_algo_cfg(cfg):
cfg.fedprox.use = False
cfg.fedprox.mu = 0.

# ---------------------------------------------------------------------- #
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether it is appropriate to put SWA here. It is not an FL algorithm but just a trick to produce another solution from those in the late stage of a training course.

# fedswa related options, Stochastic Weight Averaging (SWA)
# ---------------------------------------------------------------------- #
cfg.fedswa = CN()
cfg.fedswa.use = False
cfg.fedswa.freq = 10
cfg.fedswa.start_rnd = 30

# ---------------------------------------------------------------------- #
# Personalization related options, pFL
# ---------------------------------------------------------------------- #
Expand Down Expand Up @@ -114,5 +122,11 @@ def assert_fl_algo_cfg(cfg):
# By default, use the same lr to normal mode
cfg.personalization.lr = cfg.train.optimizer.lr

if cfg.fedswa.use:
assert cfg.fedswa.start_rnd < cfg.federate.total_round_num, \
f'`cfg.fedswa.start_rnd` {cfg.fedswa.start_rnd} must be smaller ' \
f'than `cfg.federate.total_round_num` ' \
f'{cfg.federate.total_round_num}.'


register_config("fl_algo", extend_fl_algo_cfg)
2 changes: 2 additions & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def extend_fl_setting_cfg(cfg):
cfg.federate.merge_test_data = False # For efficient simulation, users
# can choose to merge the test data and perform global evaluation,
# instead of perform test at each client
cfg.federate.merge_val_data = False # Enabled only when
# `merge_test_data` is True, also for efficient simulation

# the method name is used to internally determine composition of
# different aggregators, messages, handlers, etc.,
Expand Down
5 changes: 4 additions & 1 deletion federatedscope/core/data/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ def preprocess(self, datadict):
datadict: dict with `client_id` as key, `ClientData` as value.
"""
if self.global_cfg.federate.merge_test_data:
merge_split = ['test']
if self.global_cfg.federate.merge_val_data:
merge_split += ['val']
server_data = merge_data(
all_data=datadict,
merged_max_data_id=self.global_cfg.federate.client_num,
specified_dataset_name=['test'])
specified_dataset_name=merge_split)
# `0` indicate Server
datadict[0] = ClientData(self.global_cfg, **server_data)

Expand Down
7 changes: 5 additions & 2 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def _setup_server(self, resource_info=None, client_resource_info=None):
if self.cfg.vertical.use:
from federatedscope.vertical_fl.utils import wrap_vertical_server
server = wrap_vertical_server(server, self.cfg)
if self.cfg.fedswa.use:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation implies that, in the future, we'd better refactor the workers so that we would not need to make another worker class just for adding such a trick.

from federatedscope.core.workers.wrapper import wrap_swa_server
server = wrap_swa_server(server)
logger.info('Server has been set up ... ')
return self.feat_engr_wrapper_server(server)

Expand Down Expand Up @@ -869,7 +872,7 @@ def _setup_server(self, resource_info=None, client_resource_info=None):

logger.info('Server has been set up ... ')

return server
return self.feat_engr_wrapper_server(server)
rayrayraykk marked this conversation as resolved.
Show resolved Hide resolved

def _setup_client(self,
client_id=-1,
Expand Down Expand Up @@ -926,7 +929,7 @@ def _setup_client(self,
else:
logger.info(f'Client {client_id} has been set up ... ')

return client
return self.feat_engr_wrapper_client(client)
rayrayraykk marked this conversation as resolved.
Show resolved Hide resolved

def _handle_msg(self, msg, rcv=-1):
"""
Expand Down
3 changes: 3 additions & 0 deletions federatedscope/core/workers/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from federatedscope.core.workers.wrapper.fedswa import wrap_swa_server

__all__ = ['wrap_swa_server']
220 changes: 220 additions & 0 deletions federatedscope/core/workers/wrapper/fedswa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import copy
import logging
import types

from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results

logger = logging.getLogger(__name__)


def wrap_swa_server(server):
    """Wrap a ``Server`` instance with FedSWA (Stochastic Weight Averaging).

    Starting from round ``cfg.fedswa.start_rnd`` the server keeps a running
    average of the global model weights, updated every ``cfg.fedswa.freq``
    rounds.  After the regular final evaluation a second evaluation is run
    with the averaged (SWA) weights, and only then is the training course
    terminated.

    The wrapper monkey-patches ``check_and_move_on``, ``eval``,
    ``check_and_save`` and ``save_best_results`` on the given instance and
    returns it; the server's external interface is unchanged.

    Args:
        server: an instantiated ``Server`` worker.

    Returns:
        The same ``server`` instance with the patched methods bound.
    """
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """Check whether enough messages have arrived and, if so, move on.

        Identical to the vanilla server behaviour except that, right after
        each federated aggregation, the SWA running average of the model
        weights is created (at ``start_rnd``) or updated (every ``freq``
        rounds afterwards).

        Args:
            check_eval_result: whether the buffer to check holds \
                evaluation (rather than training) feedback.
            min_received_num: minimal number of messages required to \
                trigger the move; derived from the config when ``None``.

        Returns:
            bool: whether the server moved on to a new round / finished \
                the evaluation.
        """
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # in evaluation stage and standalone simulation mode, we assume
            # strong synchronization that receives responses from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Receiving enough feedback in the training process
                aggregated_num = self._perform_federated_aggregation()

                self.state += 1

                # FedSWA cache model
                if self.state == self._cfg.fedswa.start_rnd:
                    # Deep copy is required: `state_dict()` returns
                    # references to the live parameter tensors, which later
                    # aggregations overwrite in place -- without the copy
                    # the "cached" SWA weights would silently track the
                    # current global model.
                    self.swa_models_ws = [
                        copy.deepcopy(model.state_dict())
                        for model in self.models
                    ]
                    self.swa_rnd = 1
                elif self.state > \
                        self._cfg.fedswa.start_rnd and \
                        (self.state - self._cfg.fedswa.start_rnd) % \
                        self._cfg.fedswa.freq == 0:
                    logger.info(f'FedSWA cache {self.swa_rnd} models.')
                    # Running average: swa <- (swa * k + new) / (k + 1)
                    for swa_ws, model in zip(self.swa_models_ws,
                                             self.models):
                        new_ws = model.state_dict()
                        for key in swa_ws.keys():
                            swa_ws[key] = (swa_ws[key] * self.swa_rnd +
                                           new_ws[key]) / (self.swa_rnd + 1)
                    self.swa_rnd += 1

                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:
                # Receiving enough feedback in the evaluation process
                self._merge_and_format_eval_results()

        else:
            move_on_flag = False

        return move_on_flag

    def eval(self):
        """Evaluate the current global model (or, if ``eval_swa``, the SWA model).

        When ``self.eval_swa`` is True the SWA weights are temporarily
        loaded into the model(s) for evaluation and the FedAvg weights are
        restored afterwards.
        """
        if self._cfg.federate.make_global_eval:
            for i in range(self.model_num):
                trainer = self.trainers[i]

                if self.eval_swa:
                    # Use swa model. Deep-copy the FedAvg weights first:
                    # `load_state_dict` writes into the very tensors that
                    # `state_dict()` references, so a plain reference would
                    # be clobbered and the restore below would be a no-op.
                    fedavg_model_w = copy.deepcopy(
                        self.models[i].state_dict())
                    self.models[i].load_state_dict(self.swa_models_ws[i])

                # Perform evaluation in server
                metrics = {}
                for split in self._cfg.eval.split:
                    eval_metrics = trainer.evaluate(
                        target_data_split_name=split)
                    metrics.update(**eval_metrics)

                formatted_eval_res = self._monitor.format_eval_res(
                    metrics,
                    rnd=self.state,
                    role='Server SWA#' if self.eval_swa else 'Server #',
                    forms=self._cfg.eval.report,
                    return_raw=self._cfg.federate.make_global_eval)

                if self.eval_swa:
                    # Restore the FedAvg weights, and report the SWA
                    # metrics as the final best results.
                    self.models[i].load_state_dict(fedavg_model_w)
                    self.best_results = formatted_eval_res['Results_raw']
                else:
                    self._monitor.update_best_result(
                        self.best_results,
                        formatted_eval_res['Results_raw'],
                        results_type="server_global_eval")
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self._monitor.save_formatted_results(formatted_eval_res)
                logger.info(formatted_eval_res)
            self.check_and_save()
        else:
            if self.eval_swa:
                # Swap in the SWA weights, keeping one deep-copied backup
                # per model (a single shared backup variable would restore
                # every model from the last model's weights when
                # `model_num > 1`).
                fedavg_models_ws = [
                    copy.deepcopy(self.models[i].state_dict())
                    for i in range(self.model_num)
                ]
                for i in range(self.model_num):
                    self.models[i].load_state_dict(self.swa_models_ws[i])
            # Perform evaluation in clients
            self.broadcast_model_para(msg_type='evaluate',
                                      filter_unseen_clients=False)

            if self.eval_swa:
                for i in range(self.model_num):
                    self.models[i].load_state_dict(fedavg_models_ws[i])

    def check_and_save(self):
        """
        To save the results and save model after each evaluation, and check \
        whether to early stop.

        Before terminating, one extra evaluation pass with the SWA weights
        is triggered (``self.eval_swa`` flips to True and ``self.eval()``
        is re-entered once).
        """

        # early stopping
        if "Results_weighted_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_weighted_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_weighted_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        elif "Results_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        else:
            should_stop = False

        if should_stop:
            self._monitor.global_converged()
            self.comm_manager.send(
                Message(
                    msg_type="converged",
                    sender=self.ID,
                    receiver=list(self.comm_manager.neighbors.keys()),
                    timestamp=self.cur_timestamp,
                    state=self.state,
                ))
            self.state = self.total_round_num + 1

        if should_stop or self.state >= self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'merging results.')
            # last round or early stopped
            self.save_best_results()
            if not self._cfg.federate.make_global_eval:
                self.save_client_eval_results()

            if self.eval_swa or not hasattr(self, 'swa_models_ws'):
                # Either the SWA evaluation has just finished, or training
                # stopped before `fedswa.start_rnd` so no SWA model exists.
                self.terminate(msg_type='finish')
            else:
                self.eval_swa = True
                logger.info('Server: Evaluation with FedSWA')
                self.eval()

        # Clean the clients evaluation msg buffer
        if not self._cfg.federate.make_global_eval:
            # `eval_round` instead of `round` to avoid shadowing the builtin
            eval_round = max(self.msg_buffer['eval'].keys())
            self.msg_buffer['eval'][eval_round].clear()

        if self.state == self.total_round_num:
            # break out the loop for distributed mode
            self.state += 1

    def save_best_results(self):
        """
        To Save the best evaluation results.
        """
        if self._cfg.federate.save_to != '':
            self.aggregator.save_model(self._cfg.federate.save_to, self.state)
        formatted_best_res = self._monitor.format_eval_res(
            results=self.best_results,
            rnd="Final",
            role='Server SWA#' if self.eval_swa else 'Server #',
            forms=["raw"],
            return_raw=True)
        logger.info(formatted_best_res)
        self._monitor.save_formatted_results(formatted_best_res)

    # Initialize here so that evaluations occurring before
    # `fedswa.start_rnd` (e.g. with a small `eval.freq`) do not raise
    # AttributeError on `self.eval_swa`.
    server.eval_swa = False

    # Bind method to instance
    server.check_and_move_on = types.MethodType(check_and_move_on, server)
    server.eval = types.MethodType(eval, server)
    server.check_and_save = types.MethodType(check_and_save, server)
    server.save_best_results = types.MethodType(save_best_results, server)

    return server
44 changes: 44 additions & 0 deletions scripts/example_configs/fedswa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Example configuration for FedSWA (server-side Stochastic Weight Averaging)
# on FEMNIST with a small ConvNet, using merged server-side evaluation.
use_gpu: True
device: 0
early_stop:
  patience: 5
seed: 12345
federate:
  mode: standalone
  total_round_num: 6
  sample_client_rate: 0.1
  make_global_eval: True
  # Merge the clients' test/val data so evaluation runs once on the server
  # instead of per client (`merge_val_data` is effective only when
  # `merge_test_data` is True).
  merge_test_data: True
  merge_val_data: True
fedswa:
  use: True
  # Update the SWA weight average every `freq` rounds ...
  freq: 3
  # ... starting from this round (must be < federate.total_round_num).
  start_rnd: 2
data:
  root: data/
  type: femnist
  splits: [0.6,0.2,0.2]
  # Use only 5% of the clients for a quick example run.
  subsample: 0.05
  transform: [['ToTensor'], ['Normalize', {'mean': [0.9637], 'std': [0.1592]}]]
dataloader:
  batch_size: 10
model:
  type: convnet2
  hidden: 2048
  # 62 classes in FEMNIST (digits + upper/lower-case letters).
  out_channels: 62
train:
  local_update_steps: 1
  batch_or_epoch: epoch
  optimizer:
    lr: 0.01
    weight_decay: 0.0
  grad:
    grad_clip: 5.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: cvtrainer
eval:
  freq: 10
  metrics: ['acc', 'correct']
  count_flops: False