Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for completeness of msg_handler #388

Merged
merged 9 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions federatedscope/core/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def init_global_cfg(cfg):
# Whether to use GPU
cfg.use_gpu = False

# Whether to check the completeness of msg_handler
cfg.check_completeness = False

# Whether to print verbose logging info
cfg.verbose = 1

Expand Down
64 changes: 64 additions & 0 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def __init__(self,
self.resource_info = get_resource_info(
config.federate.resource_info_file)

# Check the completeness of msg_handler.
self.check()

def setup(self):
if self.mode == 'standalone':
self.shared_comm_queue = deque()
self._setup_for_standalone()
Expand Down Expand Up @@ -184,6 +188,7 @@ def run(self):
For the standalone mode, a shared message queue will be set up to
simulate ``receiving message``.
"""
self.setup()
if self.mode == 'standalone':
# trigger the FL course
for each_client in self.client:
Expand Down Expand Up @@ -427,3 +432,62 @@ def _handle_msg(self, msg, rcv=-1):
self.client[each_receiver].msg_handlers[msg.msg_type](msg)
self.client[each_receiver]._monitor.track_download_bytes(
download_bytes)

def check(self):
"""
Check the completeness of Server and Client.

Returns:

"""
if self.cfg.check_completeness:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is necessary to indent such a huge block of code

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update accordingly.

try:
import os
import networkx as nx
import matplotlib.pyplot as plt
# Build check graph
G = nx.DiGraph()
flags = {0: 'Client', 1: 'Server'}
msg_handler_dicts = [
self.client_class.get_msg_handler_dict(),
self.server_class.get_msg_handler_dict()
]
for flag, msg_handler_dict in zip(flags.keys(),
msg_handler_dicts):
role, oppo = flags[flag], flags[(flag + 1) % 2]
for msg_in, (handler, msgs_out) in \
msg_handler_dict.items():
for msg_out in msgs_out:
msg_in_key = f'{oppo}_{msg_in}'
handler_key = f'{role}_{handler}'
msg_out_key = f'{role}_{msg_out}'
G.add_node(msg_in_key, subset=1)
G.add_node(handler_key, subset=0 if flag else 2)
G.add_node(msg_out_key, subset=1)
G.add_edge(msg_in_key, handler_key)
G.add_edge(handler_key, msg_out_key)
pos = nx.multipartite_layout(G)
plt.figure(figsize=(20, 15))
nx.draw(G,
pos,
with_labels=True,
node_color='white',
node_size=800)
fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
plt.savefig(fig_path)
if nx.has_path(G, 'Client_join_in', 'Server_finish'):
if nx.is_weakly_connected(G):
logger.info(f'Completeness check passes! Save check '
f'results in {fig_path}.')
else:
logger.warning(
f'Completeness check raises warning for '
f'some handlers not in FL process! Save '
f'check results in {fig_path}.')
else:
logger.error(f'Completeness check fails for there is no'
f'path from `join_in` to `finish`! Save '
f'check results in {fig_path}.')
except Exception as error:
logger.warning(f'Completeness check failed for {error}!')
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the reason is that we cannot say yes or not about the correctness/completeness, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the document shows, the correctness/completeness checks only check the message-handler pairs and raise three types of logs (info, warning, and error). If something goes wrong with Python code, we'd better keep the exception stack as it is. So the return value is meaningless.

31 changes: 21 additions & 10 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ def __init__(self,
is_unseen_client=False,
*args,
**kwargs):
# Register message handlers
self.msg_handlers = dict()
self.msg_handlers_str = dict()
self._register_default_handlers()

if config is None:
return
super(Client, self).__init__(ID, state, config, model, strategy)

# the unseen_client indicates that whether this client contributes to
Expand Down Expand Up @@ -91,10 +97,6 @@ def __init__(self,
)) if self._cfg.federate.use_ss else None
self.msg_buffer = {'train': dict(), 'eval': dict()}

# Register message handlers
self.msg_handlers = dict()
self._register_default_handlers()

# Communication and communication ability
if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
self.comp_speed = float(
Expand Down Expand Up @@ -161,7 +163,7 @@ def _calculate_model_delta(self, init_model, updated_model):
else:
return model_deltas[0]

def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.

Expand All @@ -171,18 +173,23 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
self.register_handlers('assign_client_id',
self.callback_funcs_for_assign_id)
self.callback_funcs_for_assign_id, [None])
self.register_handlers('ask_for_join_in_info',
self.callback_funcs_for_join_in_info)
self.callback_funcs_for_join_in_info,
['join_in_info'])
self.register_handlers('address', self.callback_funcs_for_address)
self.register_handlers('model_para',
self.callback_funcs_for_model_para)
self.callback_funcs_for_model_para,
['model_para'])
self.register_handlers('ss_model_para',
self.callback_funcs_for_model_para)
self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
self.callback_funcs_for_model_para,
['ss_model_para'])
self.register_handlers('evaluate', self.callback_funcs_for_evaluate,
['metric'])
self.register_handlers('finish', self.callback_funcs_for_finish)
self.register_handlers('converged', self.callback_funcs_for_converged)

Expand Down Expand Up @@ -534,3 +541,7 @@ def callback_funcs_for_converged(self, message: Message):
"""

self._monitor.global_converged()

@classmethod
def get_msg_handler_dict(cls):
return cls().msg_handlers_str
33 changes: 24 additions & 9 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ def __init__(self,
strategy=None,
unseen_clients_id=None,
**kwargs):
# Register message handlers
self.msg_handlers = dict()
self.msg_handlers_str = dict()
self._register_default_handlers()

if config is None:
return
super(Server, self).__init__(ID, state, config, model, strategy)

self.data = data
Expand Down Expand Up @@ -165,10 +171,6 @@ def __init__(self,
self.client_resource_info = kwargs['client_resource_info'] \
if 'client_resource_info' in kwargs else None

# Register message handlers
self.msg_handlers = dict()
self._register_default_handlers()

# Initialize communication manager and message buffer
self.msg_buffer = {'train': dict(), 'eval': dict()}
self.staled_msg_buffer = list()
Expand Down Expand Up @@ -206,7 +208,7 @@ def total_round_num(self, value):
def register_noise_injector(self, func):
self._noise_injector = func

def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.

Expand All @@ -216,12 +218,21 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
self.register_handlers('join_in', self.callback_funcs_for_join_in)
self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
self.register_handlers('model_para', self.callback_funcs_model_para)
self.register_handlers('metrics', self.callback_funcs_for_metrics)
self.register_handlers('join_in', self.callback_funcs_for_join_in, [
'assign_client_id', 'ask_for_join_in_info', 'address', 'model_para'
])
self.register_handlers('join_in_info', self.callback_funcs_for_join_in,
[
'assign_client_id', 'ask_for_join_in_info',
'address', 'model_para'
])
self.register_handlers('model_para', self.callback_funcs_model_para,
['model_para', 'finish'])
self.register_handlers('metrics', self.callback_funcs_for_metrics,
['converged'])

def run(self):
"""
Expand Down Expand Up @@ -981,3 +992,7 @@ def callback_funcs_for_metrics(self, message: Message):
self.msg_buffer['eval'][round][sender] = content

return self.check_and_move_on(check_eval_result=True)

@classmethod
def get_msg_handler_dict(cls):
return cls().msg_handlers_str
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

test_requires = []

dev_requires = test_requires + ['pre-commit']
dev_requires = test_requires + ['pre-commit', 'networkx', 'matplotlib']

org_requires = ['paramiko==2.11.0', 'celery[redis]', 'cmd2']

Expand Down