Merge pull request #463 from rayrayraykk/blocal
joneswong authored Dec 8, 2022
2 parents d66044a + 4654ae9 commit e2ce0ce
Showing 7 changed files with 106 additions and 68 deletions.
12 changes: 3 additions & 9 deletions federatedscope/autotune/fedex/server.py
@@ -326,15 +326,9 @@ def check_and_move_on(self,
train_msg_buffer[client_id][2])

# Trigger the monitor here (for training)
if 'dissim' in self._cfg.eval.monitoring:
from federatedscope.core.auxiliaries.utils import \
calc_blocal_dissim
# TODO: fix load_state_dict
B_val = calc_blocal_dissim(
model.load_state_dict(strict=False), msg_list)
formatted_eval_res = self._monitor.format_eval_res(
B_val, rnd=self.state, role='Server #')
logger.info(formatted_eval_res)
self._monitor.calc_model_metric(self.model.state_dict(),
msg_list,
rnd=self.state)

# Aggregate
agg_info = {
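
The inline 'dissim' branch is replaced here (and in the two other servers changed below) by a single call to Monitor.calc_model_metric, driven entirely by cfg.eval.monitoring. A minimal configuration sketch follows; the cfg construction is an assumption, while the metric names come from the calc_<name> convention in the new metric_calculator.py.

# Sketch: enabling the new model metrics via configuration (assumed setup).
from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
# Each entry resolves to a calc_<metric> function in
# federatedscope/core/monitors/metric_calculator.py, e.g.
# 'blocal_dissim' -> calc_blocal_dissim and 'l2_dissim' -> calc_l2_dissim.
cfg.eval.monitoring = ['blocal_dissim', 'l2_dissim']

# At aggregation time the server now simply calls
#   self._monitor.calc_model_metric(self.model.state_dict(), msg_list,
#                                   rnd=self.state)
# where msg_list holds (data_size, model_state_dict) tuples from the clients.
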
7 changes: 5 additions & 2 deletions federatedscope/core/data/base_translator.py
@@ -59,7 +59,7 @@ def split(self, dataset):
datadict = self.split_to_client(train, val, test)
return datadict

def split_train_val_test(self, dataset):
def split_train_val_test(self, dataset, cfg=None):
"""
Split dataset to train, val, test if not provided.
@@ -68,7 +68,10 @@ def split_train_val_test(self, dataset):
"""
from torch.utils.data import Dataset, Subset

splits = self.global_cfg.data.splits
if cfg is not None:
splits = cfg.data.splits
else:
splits = self.global_cfg.data.splits
if isinstance(dataset, tuple):
# No need to split train/val/test for tuple dataset.
error_msg = 'If dataset is tuple, it must contains ' \
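
The new optional cfg argument lets a per-client configuration override the global train/val/test ratios. A self-contained sketch of that logic is below; the SimpleNamespace objects and the sequential split are simplified stand-ins for FederatedScope's config nodes and splitter, which may shuffle before splitting.

# Simplified stand-in for BaseDataTranslator.split_train_val_test.
from types import SimpleNamespace
from torch.utils.data import Subset

def split_train_val_test(dataset, global_cfg, cfg=None):
    # Prefer the client-specific ratios when a client cfg is supplied.
    splits = cfg.data.splits if cfg is not None else global_cfg.data.splits
    n = len(dataset)
    n_train, n_val = int(splits[0] * n), int(splits[1] * n)
    return (Subset(dataset, range(n_train)),
            Subset(dataset, range(n_train, n_train + n_val)),
            Subset(dataset, range(n_train + n_val, n)))

global_cfg = SimpleNamespace(data=SimpleNamespace(splits=[0.8, 0.1, 0.1]))
client_cfg = SimpleNamespace(data=SimpleNamespace(splits=[0.6, 0.2, 0.2]))
data = list(range(100))
print([len(s) for s in split_train_val_test(data, global_cfg)])              # [80, 10, 10]
print([len(s) for s in split_train_val_test(data, global_cfg, client_cfg)])  # [60, 20, 20]
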
2 changes: 1 addition & 1 deletion federatedscope/core/data/dummy_translator.py
@@ -33,7 +33,7 @@ def split(self, dataset):
else:
# Do not have train/val/test
train, val, test = self.split_train_val_test(
dataset[client_id])
dataset[client_id], client_cfg)
tmp_dict = dict(train=train, val=val, test=test)
# Only for graph-level task, get number of graph labels
if client_cfg.model.task.startswith('graph') and \
68 changes: 68 additions & 0 deletions federatedscope/core/monitors/metric_calculator.py
@@ -153,6 +153,7 @@ def _check_and_parse(self, ctx):
return y_true, y_pred, y_prob


# Metric for performance
def eval_correct(y_true, y_pred, **kwargs):
correct_list = []

@@ -289,3 +290,70 @@ def eval_imp_ratio(ctx, y_true, y_prob, y_pred, **kwargs):
'std': (None, False),
**dict.fromkeys([f'hits@{n}' for n in range(1, 101)], (eval_hits, True))
}


# Metric for model dissimilarity
def calc_blocal_dissim(last_model, local_updated_models):
"""
Arguments:
last_model (dict): the state of last round.
local_updated_models (list): each element is (data_size, model).
Returns:
dict: b_local_dissimilarity, the measurements proposed in \
"Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated \
Optimization in Heterogeneous Networks".
"""
# for k, v in last_model.items():
# print(k, v)
# for i, elem in enumerate(local_updated_models):
# print(i, elem)
local_grads = []
weights = []
local_gnorms = []
for tp in local_updated_models:
weights.append(tp[0])
grads = dict()
gnorms = dict()
for k, v in tp[1].items():
grad = v - last_model[k]
grads[k] = grad
gnorms[k] = torch.sum(grad**2)
local_grads.append(grads)
local_gnorms.append(gnorms)
weights = np.asarray(weights)
weights = weights / np.sum(weights)
avg_gnorms = dict()
global_grads = dict()

for i in range(len(local_updated_models)):
gnorms = local_gnorms[i]
for k, v in gnorms.items():
if k not in avg_gnorms:
avg_gnorms[k] = .0
avg_gnorms[k] += weights[i] * v
grads = local_grads[i]
for k, v in grads.items():
if k not in global_grads:
global_grads[k] = torch.zeros_like(v, dtype=torch.float32)
global_grads[k] += weights[i] * v
b_local_dissimilarity = dict()
for k in avg_gnorms:
b_local_dissimilarity[k] = np.sqrt(
avg_gnorms[k].item() / torch.sum(global_grads[k]**2).item())
return b_local_dissimilarity


def calc_l2_dissim(last_model, local_updated_models):
l2_dissimilarity = dict()
l2_dissimilarity['raw'] = []
for tp in local_updated_models:
grads = dict()
for key, w in tp[1].items():
grad = w - last_model[key]
grads[key] = grad
grad_norm = \
torch.norm(torch.cat([v.flatten() for v in grads.values()])).item()
l2_dissimilarity['raw'].append(grad_norm)
l2_dissimilarity['mean'] = np.mean(l2_dissimilarity['raw'])
return l2_dissimilarity
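
For each parameter tensor, calc_blocal_dissim returns the square root of the data-size-weighted average of the squared client update norms divided by the squared norm of the weighted-average update, i.e. the B-local dissimilarity from the paper cited in its docstring; calc_l2_dissim records the L2 norm of every client update plus their mean. A toy invocation with dummy tensors (the import path matches the new module):

import torch
from federatedscope.core.monitors.metric_calculator import (
    calc_blocal_dissim, calc_l2_dissim)

last_model = {'w': torch.zeros(3), 'b': torch.zeros(1)}
# Each element is a (data_size, model_state_dict) tuple, as the servers
# pass via msg_list.
local_updates = [
    (10, {'w': torch.tensor([1., 0., 0.]), 'b': torch.tensor([0.5])}),
    (30, {'w': torch.tensor([0., 1., 0.]), 'b': torch.tensor([0.5])}),
]

print(calc_blocal_dissim(last_model, local_updates))
# 'b' gets identical updates, so its dissimilarity is exactly 1.0;
# 'w' gets orthogonal updates, so its value exceeds 1.0.
print(calc_l2_dissim(last_model, local_updates))
# {'raw': [per-client update norms], 'mean': their average}
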
66 changes: 23 additions & 43 deletions federatedscope/core/monitors/monitor.py
@@ -6,6 +6,7 @@
import shutil
import datetime
from collections import defaultdict
from importlib import import_module

import numpy as np

@@ -48,6 +49,7 @@ class Monitor(object):
SUPPORTED_FORMS = ['weighted_avg', 'avg', 'fairness', 'raw']

def __init__(self, cfg, monitored_object=None):
self.cfg = cfg
self.log_res_best = {}
self.outdir = cfg.outdir
self.use_wandb = cfg.wandb.use
@@ -477,54 +479,32 @@ def format_eval_res(self,
return round_formatted_results_raw if return_raw else \
round_formatted_results

def calc_blocal_dissim(self, last_model, local_updated_models):
def calc_model_metric(self, last_model, local_updated_models, rnd):
"""
Arguments:
last_model (dict): the state of last round.
local_updated_models (list): each element is model.
local_updated_models (list): each element is (data_size, model).
Returns:
dict: b_local_dissimilarity, the measurements proposed in \
"Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated \
Optimization in Heterogeneous Networks".
"""
# for k, v in last_model.items():
# print(k, v)
# for i, elem in enumerate(local_updated_models):
# print(i, elem)
local_grads = []
weights = []
local_gnorms = []
for tp in local_updated_models:
weights.append(tp[0])
grads = dict()
gnorms = dict()
for k, v in tp[1].items():
grad = v - last_model[k]
grads[k] = grad
gnorms[k] = torch.sum(grad**2)
local_grads.append(grads)
local_gnorms.append(gnorms)
weights = np.asarray(weights)
weights = weights / np.sum(weights)
avg_gnorms = dict()
global_grads = dict()
for i in range(len(local_updated_models)):
gnorms = local_gnorms[i]
for k, v in gnorms.items():
if k not in avg_gnorms:
avg_gnorms[k] = .0
avg_gnorms[k] += weights[i] * v
grads = local_grads[i]
for k, v in grads.items():
if k not in global_grads:
global_grads[k] = torch.zeros_like(v)
global_grads[k] += weights[i] * v
b_local_dissimilarity = dict()
for k in avg_gnorms:
b_local_dissimilarity[k] = np.sqrt(
avg_gnorms[k].item() / torch.sum(global_grads[k]**2).item())
return b_local_dissimilarity
dict: model_metric_dict
"""
model_metric_dict = {}
for metric in self.cfg.eval.monitoring:
func_name = f'calc_{metric}'
calc_metric = getattr(
import_module(
'federatedscope.core.monitors.metric_calculator'),
func_name)
metric_value = calc_metric(last_model, local_updated_models)
model_metric_dict[f'train_{metric}'] = metric_value
formatted_log = {
'Role': 'Server #',
'Round': rnd,
'Results_model_metric': model_metric_dict
}
logger.info(formatted_log)

return model_metric_dict

def convert_size(self, size_bytes):
"""
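
Because calc_model_metric resolves each metric by name at call time (import_module plus getattr), adding a metric only needs a calc_<name> function visible in metric_calculator and an entry in cfg.eval.monitoring. The extension below is purely illustrative and not part of this commit; the metric, its name, and the runtime attachment are assumptions.

import torch
import federatedscope.core.monitors.metric_calculator as mc

def calc_mean_update_norm(last_model, local_updated_models):
    # Hypothetical extra metric: average L2 norm of the client updates.
    norms = []
    for _, model in local_updated_models:
        delta = torch.cat([(v - last_model[k]).flatten()
                           for k, v in model.items()])
        norms.append(torch.norm(delta).item())
    return sum(norms) / len(norms)

# Attach it under the calc_<name> convention so the getattr lookup in
# Monitor.calc_model_metric can find it, then add 'mean_update_norm' to
# cfg.eval.monitoring (assuming the config check accepts the new name).
mc.calc_mean_update_norm = calc_mean_update_norm
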
10 changes: 3 additions & 7 deletions federatedscope/core/workers/server.py
@@ -445,13 +445,9 @@ def _perform_federated_aggregation(self):
staleness.append((client_id, self.state - state))

# Trigger the monitor here (for training)
if 'dissim' in self._cfg.eval.monitoring:
# TODO: fix this
B_val = self._monitor.calc_blocal_dissim(
model.load_state_dict(strict=False), msg_list)
formatted_eval_res = self._monitor.format_eval_res(
B_val, rnd=self.state, role='Server #')
logger.info(formatted_eval_res)
self._monitor.calc_model_metric(self.model.state_dict(),
msg_list,
rnd=self.state)

# Aggregate
aggregated_num = len(msg_list)
9 changes: 3 additions & 6 deletions federatedscope/gfl/fedsageplus/worker.py
@@ -193,12 +193,9 @@ def check_and_move_on(self, check_eval_result=False):
msg_list.append(train_msg_buffer[client_id])

# Trigger the monitor here (for training)
if 'dissim' in self._cfg.eval.monitoring:
B_val = self._monitor.calc_blocal_dissim(
self.model.load_state_dict(), msg_list)
formatted_logs = self._monitor.format_eval_res(
B_val, rnd=self.state, role='Server #')
logger.info(formatted_logs)
self._monitor.calc_model_metric(self.model.state_dict(),
msg_list,
rnd=self.state)

# Aggregate
agg_info = {
