Skip to content

Commit

Permalink
add iid splitter (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyxclack authored Oct 17, 2022
1 parent ad7b8e5 commit c1edd8a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
3 changes: 3 additions & 0 deletions federatedscope/core/auxiliaries/splitter_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def get_splitter(config):
elif config.data.splitter == 'rand_chunk':
from federatedscope.core.splitters.graph import RandChunkSplitter
splitter = RandChunkSplitter(client_num, **kwargs)
elif config.data.splitter == 'iid':
from federatedscope.core.splitters.generic import IIDSplitter
splitter = IIDSplitter(client_num)
else:
logger.warning(f'Splitter {config.data.splitter} not found.')
splitter = None
Expand Down
3 changes: 2 additions & 1 deletion federatedscope/core/splitters/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from federatedscope.core.splitters.generic.lda_splitter import LDASplitter
from federatedscope.core.splitters.generic.iid_splitter import IIDSplitter

__all__ = ['LDASplitter']
__all__ = ['LDASplitter', 'IIDSplitter']
17 changes: 17 additions & 0 deletions federatedscope/core/splitters/generic/iid_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
from federatedscope.core.splitters import BaseSplitter


class IIDSplitter(BaseSplitter):
    """Splitter that partitions a dataset into IID shards, one per client.

    Samples are shuffled uniformly at random (via the global ``np.random``
    state) and then cut into ``client_num`` contiguous, near-equal shards.

    Args:
        client_num: number of shards to produce; forwarded unchanged to
            ``BaseSplitter.__init__``.
    """
    def __init__(self, client_num):
        super(IIDSplitter, self).__init__(client_num)

    def __call__(self, dataset, prior=None):
        """Shuffle ``dataset`` and split it into per-client lists.

        Args:
            dataset: any iterable of samples; it is materialised into a list
                and the caller's object is never mutated.
            prior: accepted for signature compatibility; not used by this
                splitter.

        Returns:
            A list of lists of samples. Shard ``k`` covers shuffled positions
            ``[k * n // client_num, (k + 1) * n // client_num)``, so shard
            sizes differ by at most one. Samples are the original objects
            (they are not converted through NumPy).
        """
        samples = list(dataset)
        total = len(samples)
        # NOTE(review): self.client_num is presumably set by BaseSplitter's
        # constructor. A non-positive count degenerates to a single shard
        # holding everything, which is what the previous implementation did.
        num_shards = max(self.client_num, 1)

        # Shuffle positions rather than the samples themselves: the samples
        # never pass through an ndarray, so tuples, tensors and ragged
        # records keep their type and cannot trigger array-coercion errors.
        order = list(range(total))
        np.random.shuffle(order)

        # Integer arithmetic for the cut points. Accumulating 1/client_num in
        # floating point and truncating can land just below an integer
        # (e.g. 0.7999999999999999 * 10 -> 7), yielding an empty shard.
        bounds = [(k * total) // num_shards for k in range(num_shards + 1)]

        return [[samples[i] for i in order[lo:hi]]
                for lo, hi in zip(bounds[:-1], bounds[1:])]

0 comments on commit c1edd8a

Please sign in to comment.