Skip to content

Commit

Permalink
Add checks for completeness of msg_handler (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk authored Nov 24, 2022
1 parent 4eeff57 commit 153d363
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_distribute.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [push, pull_request]
jobs:
run:
runs-on: ${{ matrix.os }}
timeout-minutes: 10
timeout-minutes: 20
strategy:
matrix:
os: [ubuntu-latest]
Expand Down
30 changes: 24 additions & 6 deletions federatedscope/contrib/worker/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,36 @@


# Build your worker here.
class MyClient(Client):
pass
class MyServer(Server):
def _register_default_handlers(self):
self.register_handlers('join_in', self.callback_funcs_for_join_in,
['assign_client_id', 'address', 'model_para'])
self.register_handlers('join_in_info', self.callback_funcs_for_join_in,
['address', 'model_para'])
self.register_handlers('model_para', self.callback_funcs_model_para,
['model_para', 'evaluate', 'finish'])
self.register_handlers('metrics', self.callback_funcs_for_metrics,
['converged'])


class MyServer(Server):
pass
class MyClient(Client):
def _register_default_handlers(self):
self.register_handlers('assign_client_id',
self.callback_funcs_for_assign_id, [None])
self.register_handlers('address', self.callback_funcs_for_address)
self.register_handlers('model_para',
self.callback_funcs_for_model_para,
['model_para', 'ss_model_para'])
self.register_handlers('evaluate', self.callback_funcs_for_evaluate,
['metrics'])
self.register_handlers('finish', self.callback_funcs_for_finish)
self.register_handlers('converged', self.callback_funcs_for_converged)


def call_my_worker(method):
if method == 'mymethod':
if method == 'myfedavg':
worker_builder = {'client': MyClient, 'server': MyServer}
return worker_builder


register_worker('mymethod', call_my_worker)
register_worker('myfedavg', call_my_worker)
3 changes: 3 additions & 0 deletions federatedscope/core/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ def init_global_cfg(cfg):
# Whether to use GPU
cfg.use_gpu = False

# Whether to check the completeness of msg_handler
cfg.check_completeness = False

# Whether to print verbose logging info
cfg.verbose = 1

Expand Down
126 changes: 126 additions & 0 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __init__(self,
self.resource_info = get_resource_info(
config.federate.resource_info_file)

# Check the completeness of msg_handler.
self.check()

# Set up for Runner
self._set_up()

Expand Down Expand Up @@ -208,6 +211,65 @@ def _setup_client(self,

return client

def check(self):
    """
    Check the completeness of Server and Client.

    Builds a directed graph from the message-handler tables returned by
    ``client_class.get_msg_handler_dict()`` and
    ``server_class.get_msg_handler_dict()``, saves a picture of it to
    ``<cfg.outdir>/msg_handler.png`` and logs whether ``Server_finish``
    is reachable from ``Client_join_in``.

    Does nothing unless ``cfg.check_completeness`` is set. The check is
    best-effort: any exception raised while running it (e.g., a missing
    optional dependency) is logged as a warning and never propagated.
    """
    if not self.cfg.check_completeness:
        return
    try:
        import os
        import networkx as nx
        import matplotlib.pyplot as plt

        # Build check graph. Index 0 is the client table, index 1 the
        # server table; a message received by one role is attributed to
        # the opposite role as its sender.
        G = nx.DiGraph()
        roles = ('Client', 'Server')
        msg_handler_dicts = (
            self.client_class.get_msg_handler_dict(),
            self.server_class.get_msg_handler_dict(),
        )
        for flag, msg_handler_dict in enumerate(msg_handler_dicts):
            role, oppo = roles[flag], roles[1 - flag]
            # Messages share the middle layer (subset 1); server handlers
            # are placed in subset 0 and client handlers in subset 2.
            handler_subset = 0 if flag else 2
            for msg_in, (handler, msgs_out) in msg_handler_dict.items():
                msg_in_key = f'{oppo}_{msg_in}'
                handler_key = f'{role}_{handler}'
                for msg_out in msgs_out:
                    msg_out_key = f'{role}_{msg_out}'
                    G.add_node(msg_in_key, subset=1)
                    G.add_node(handler_key, subset=handler_subset)
                    G.add_node(msg_out_key, subset=1)
                    G.add_edge(msg_in_key, handler_key)
                    G.add_edge(handler_key, msg_out_key)

        pos = nx.multipartite_layout(G)
        fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
        fig = plt.figure(figsize=(20, 15))
        try:
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            plt.savefig(fig_path)
        finally:
            # Release the figure so it does not stay registered in pyplot
            # for the rest of the run.
            plt.close(fig)

        # ``nx.has_path`` raises if either endpoint is not in the graph;
        # a missing endpoint means there is no path, so test membership
        # first and report it through the dedicated error branch.
        start, end = 'Client_join_in', 'Server_finish'
        reachable = start in G and end in G and nx.has_path(G, start, end)
        if not reachable:
            logger.error(f'Completeness check fails for there is no '
                         f'path from `join_in` to `finish`! Save '
                         f'check results in {fig_path}.')
        elif nx.is_weakly_connected(G):
            logger.info(f'Completeness check passes! Save check '
                        f'results in {fig_path}.')
        else:
            logger.warning(f'Completeness check raises warning for '
                           f'some handlers not in FL process! Save '
                           f'check results in {fig_path}.')
    except Exception as error:
        # Deliberately broad: this optional diagnostic must never stop
        # the runner from being constructed.
        logger.warning(f'Completeness check failed for {error}!')
    return


class StandaloneRunner(BaseRunner):
def _set_up(self):
Expand Down Expand Up @@ -528,6 +590,10 @@ def __init__(self,
self.resource_info = get_resource_info(
config.federate.resource_info_file)

# Check the completeness of msg_handler.
self.check()

def setup(self):
if self.mode == 'standalone':
self.shared_comm_queue = deque()
self._setup_for_standalone()
Expand Down Expand Up @@ -635,6 +701,7 @@ def run(self):
For the standalone mode, a shared message queue will be set up to
simulate ``receiving message``.
"""
self.setup()
if self.mode == 'standalone':
# trigger the FL course
for each_client in self.client:
Expand Down Expand Up @@ -870,3 +937,62 @@ def _handle_msg(self, msg, rcv=-1):
self.client[each_receiver].msg_handlers[msg.msg_type](msg)
self.client[each_receiver]._monitor.track_download_bytes(
download_bytes)

def check(self):
    """
    Check the completeness of Server and Client.

    Builds a directed graph from the message-handler tables returned by
    ``client_class.get_msg_handler_dict()`` and
    ``server_class.get_msg_handler_dict()``, saves a picture of it to
    ``<cfg.outdir>/msg_handler.png`` and logs whether ``Server_finish``
    is reachable from ``Client_join_in``.

    Does nothing unless ``cfg.check_completeness`` is set. The check is
    best-effort: any exception raised while running it (e.g., a missing
    optional dependency) is logged as a warning and never propagated.
    """
    if not self.cfg.check_completeness:
        return
    try:
        import os
        import networkx as nx
        import matplotlib.pyplot as plt

        # Build check graph. Index 0 is the client table, index 1 the
        # server table; a message received by one role is attributed to
        # the opposite role as its sender.
        G = nx.DiGraph()
        roles = ('Client', 'Server')
        msg_handler_dicts = (
            self.client_class.get_msg_handler_dict(),
            self.server_class.get_msg_handler_dict(),
        )
        for flag, msg_handler_dict in enumerate(msg_handler_dicts):
            role, oppo = roles[flag], roles[1 - flag]
            # Messages share the middle layer (subset 1); server handlers
            # are placed in subset 0 and client handlers in subset 2.
            handler_subset = 0 if flag else 2
            for msg_in, (handler, msgs_out) in msg_handler_dict.items():
                msg_in_key = f'{oppo}_{msg_in}'
                handler_key = f'{role}_{handler}'
                for msg_out in msgs_out:
                    msg_out_key = f'{role}_{msg_out}'
                    G.add_node(msg_in_key, subset=1)
                    G.add_node(handler_key, subset=handler_subset)
                    G.add_node(msg_out_key, subset=1)
                    G.add_edge(msg_in_key, handler_key)
                    G.add_edge(handler_key, msg_out_key)

        pos = nx.multipartite_layout(G)
        fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
        fig = plt.figure(figsize=(20, 15))
        try:
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            plt.savefig(fig_path)
        finally:
            # Release the figure so it does not stay registered in pyplot
            # for the rest of the run.
            plt.close(fig)

        # ``nx.has_path`` raises if either endpoint is not in the graph;
        # a missing endpoint means there is no path, so test membership
        # first and report it through the dedicated error branch.
        start, end = 'Client_join_in', 'Server_finish'
        reachable = start in G and end in G and nx.has_path(G, start, end)
        if not reachable:
            logger.error(f'Completeness check fails for there is no '
                         f'path from `join_in` to `finish`! Save '
                         f'check results in {fig_path}.')
        elif nx.is_weakly_connected(G):
            logger.info(f'Completeness check passes! Save check '
                        f'results in {fig_path}.')
        else:
            logger.warning(f'Completeness check raises warning for '
                           f'some handlers not in FL process! Save '
                           f'check results in {fig_path}.')
    except Exception as error:
        # Deliberately broad: this optional diagnostic must never stop
        # the runner from being constructed.
        logger.warning(f'Completeness check failed for {error}!')
    return
29 changes: 18 additions & 11 deletions federatedscope/core/workers/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
class BaseClient(Worker):
def __init__(self, ID, state, config, model, strategy):
super(BaseClient, self).__init__(ID, state, config, model, strategy)
# TODO: move to worker
self.msg_handlers = dict()
self.msg_handlers_str = dict()

# TODO: move to worker
def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.
Expand All @@ -19,6 +18,7 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
"""
Expand All @@ -43,17 +43,24 @@ def _register_default_handlers(self):
============================ ==================================
"""
self.register_handlers('assign_client_id',
self.callback_funcs_for_assign_id)
self.callback_funcs_for_assign_id, [None])
self.register_handlers('ask_for_join_in_info',
self.callback_funcs_for_join_in_info)
self.register_handlers('address', self.callback_funcs_for_address)
self.callback_funcs_for_join_in_info,
['join_in_info'])
self.register_handlers('address', self.callback_funcs_for_address,
[None])
self.register_handlers('model_para',
self.callback_funcs_for_model_para)
self.callback_funcs_for_model_para,
['model_para', 'ss_model_para'])
self.register_handlers('ss_model_para',
self.callback_funcs_for_model_para)
self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
self.register_handlers('finish', self.callback_funcs_for_finish)
self.register_handlers('converged', self.callback_funcs_for_converged)
self.callback_funcs_for_model_para,
['ss_model_para', 'model_para'])
self.register_handlers('evaluate', self.callback_funcs_for_evaluate,
['metrics'])
self.register_handlers('finish', self.callback_funcs_for_finish,
[None])
self.register_handlers('converged', self.callback_funcs_for_converged,
[None])

@abc.abstractmethod
def run(self):
Expand Down
19 changes: 12 additions & 7 deletions federatedscope/core/workers/base_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
class BaseServer(Worker):
def __init__(self, ID, state, config, model, strategy):
super(BaseServer, self).__init__(ID, state, config, model, strategy)
# TODO: move to worker
self.msg_handlers = dict()
self.msg_handlers_str = dict()

# TODO: move to worker
def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.
Expand All @@ -19,6 +18,7 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
"""
Expand All @@ -38,10 +38,15 @@ def _register_default_handlers(self):
``metrics`` ``callback_funcs_for_metrics``
============================ ==================================
"""
self.register_handlers('join_in', self.callback_funcs_for_join_in)
self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
self.register_handlers('model_para', self.callback_funcs_model_para)
self.register_handlers('metrics', self.callback_funcs_for_metrics)
self.register_handlers('join_in', self.callback_funcs_for_join_in, [
'assign_client_id', 'ask_for_join_in_info', 'address', 'model_para'
])
self.register_handlers('join_in_info', self.callback_funcs_for_join_in,
['address', 'model_para'])
self.register_handlers('model_para', self.callback_funcs_model_para,
['model_para', 'evaluate', 'finish'])
self.register_handlers('metrics', self.callback_funcs_for_metrics,
['converged'])

@abc.abstractmethod
def run(self):
Expand Down
5 changes: 3 additions & 2 deletions federatedscope/core/workers/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def __init__(self, ID=-1, state=0, config=None, model=None, strategy=None):
self._model = model
self._cfg = config
self._strategy = strategy
self._mode = self._cfg.federate.mode.lower()
self._monitor = Monitor(config, monitored_object=self)
if self._cfg is not None:
self._mode = self._cfg.federate.mode.lower()
self._monitor = Monitor(config, monitored_object=self)

@property
def ID(self):
Expand Down
14 changes: 10 additions & 4 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ def __init__(self,
is_unseen_client=False,
*args,
**kwargs):

super(Client, self).__init__(ID, state, config, model, strategy)
# Register message handlers
self._register_default_handlers()

# Un-configured worker
if config is None:
return

# the unseen_client indicates that whether this client contributes to
# FL process by training on its local data and uploading the local
Expand Down Expand Up @@ -109,9 +114,6 @@ def __init__(self,
)) if self._cfg.federate.use_ss else None
self.msg_buffer = {'train': dict(), 'eval': dict()}

# Register message handlers
self._register_default_handlers()

# Communication and communication ability
if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
self.comp_speed = float(
Expand Down Expand Up @@ -527,3 +529,7 @@ def callback_funcs_for_converged(self, message: Message):
message: The received message
"""
self._monitor.global_converged()

@classmethod
def get_msg_handler_dict(cls):
    """
    Return the ``msg_handlers_str`` table of a freshly created worker.

    The class is instantiated with no arguments and the instance's
    ``msg_handlers_str`` attribute is returned unchanged.
    """
    worker = cls()
    return worker.msg_handlers_str
14 changes: 10 additions & 4 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ def __init__(self,
strategy=None,
unseen_clients_id=None,
**kwargs):

super(Server, self).__init__(ID, state, config, model, strategy)
# Register message handlers
self._register_default_handlers()

# Un-configured worker
if config is None:
return

self.data = data
self.device = device
Expand Down Expand Up @@ -186,9 +191,6 @@ def __init__(self,
self.client_resource_info = kwargs['client_resource_info'] \
if 'client_resource_info' in kwargs else None

# Register message handlers
self._register_default_handlers()

# Initialize communication manager and message buffer
self.msg_buffer = {'train': dict(), 'eval': dict()}
self.staled_msg_buffer = list()
Expand Down Expand Up @@ -987,3 +989,7 @@ def callback_funcs_for_metrics(self, message: Message):
self.msg_buffer['eval'][rnd][sender] = content

return self.check_and_move_on(check_eval_result=True)

@classmethod
def get_msg_handler_dict(cls):
    """
    Return the ``msg_handlers_str`` table of a freshly created worker.

    The class is instantiated with no arguments and the instance's
    ``msg_handlers_str`` attribute is returned unchanged.
    """
    worker = cls()
    return worker.msg_handlers_str
Loading

0 comments on commit 153d363

Please sign in to comment.