Skip to content

Commit

Permalink
Merge pull request #466 from rayrayraykk/fix_client_cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
joneswong authored Dec 8, 2022
2 parents e2ce0ce + 6058f23 commit 2cb5585
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
1 change: 0 additions & 1 deletion federatedscope/core/configs/cfg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def extend_evaluation_cfg(cfg):

# Monitoring, e.g., 'dissim' for B-local dissimilarity
cfg.eval.monitoring = []

cfg.eval.count_flops = True

# ---------------------------------------------------------------------- #
Expand Down
38 changes: 26 additions & 12 deletions federatedscope/core/data/base_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging

from scipy.sparse.csc import csc_matrix

from federatedscope.core.data.utils import merge_data
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader

Expand Down Expand Up @@ -145,13 +148,16 @@ class ClientData(dict):
test: test dataset, which will be converted to ``Dataloader``
Note:
Key ``data`` in ``ClientData`` is the raw dataset.
Key ``{split}_data`` in ``ClientData`` is the raw dataset.
Key ``{split}`` in ``ClientData`` is the dataloader.
"""
SPLIT_NAMES = ['train', 'val', 'test']

def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs):
self.client_cfg = None
self.train = train
self.val = val
self.test = test
self.train_data = train
self.val_data = val
self.test_data = test
self.setup(client_cfg)
if kwargs is not None:
for key in kwargs:
Expand All @@ -168,18 +174,26 @@ def setup(self, new_client_cfg=None):
Returns:
Bool: Status for indicating whether the client_cfg is updated
"""
# if `batch_size` or `shuffle` change, reinstantiate DataLoader
# if `batch_size` or `shuffle` change, re-instantiate DataLoader
if self.client_cfg is not None:
if dict(self.client_cfg.dataloader) == dict(
new_client_cfg.dataloader):
return False

self.client_cfg = new_client_cfg
if self.train is not None:
self['train'] = get_dataloader(self.train, self.client_cfg,
'train')
if self.val is not None:
self['val'] = get_dataloader(self.val, self.client_cfg, 'val')
if self.test is not None:
self['test'] = get_dataloader(self.test, self.client_cfg, 'test')

for split_data, split_name in zip(
[self.train_data, self.val_data, self.test_data],
self.SPLIT_NAMES):
if split_data is not None:
# csc_matrix does not have ``__len__`` attributes
if isinstance(split_data, csc_matrix):
self[split_name] = get_dataloader(split_data,
self.client_cfg,
split_name)
elif len(split_data) > 0:
self[split_name] = get_dataloader(split_data,
self.client_cfg,
split_name)

return True

0 comments on commit 2cb5585

Please sign in to comment.