add data, model, and algo that FedSAM needs #453

Merged (9 commits) on Dec 8, 2022
58 changes: 58 additions & 0 deletions federatedscope/contrib/model/fedsam_convnet.py
@@ -0,0 +1,58 @@
'''The implementations of ASAM and SAM are borrowed from
https://github.com/debcaldarola/fedsam
Caldarola, D., Caputo, B., & Ciccone, M.
Improving Generalization in Federated Learning by Seeking Flat Minima,
European Conference on Computer Vision (ECCV) 2022.
'''
import os
import re
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from federatedscope.register import register_model


class Conv2Model(nn.Module):
def __init__(self, num_classes):
super(Conv2Model, self).__init__()
self.num_classes = num_classes

self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5),
nn.ReLU(), nn.MaxPool2d(kernel_size=2))

self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5),
nn.ReLU(), nn.MaxPool2d(kernel_size=2))

self.classifier = nn.Sequential(nn.Linear(64 * 5 * 5, 384), nn.ReLU(),
nn.Linear(384, 192), nn.ReLU(),
nn.Linear(192, self.num_classes))

self.size = self.model_size()

def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = torch.reshape(x, (x.shape[0], -1))
x = self.classifier(x)
return x

    def model_size(self):
        # Note: this sums only the first dimension of each parameter tensor,
        # not the total number of elements
        tot_size = 0
        for param in self.parameters():
            tot_size += param.size()[0]
        return tot_size


def call_fedsam_conv2(model_config, local_data):
    if model_config.type == 'fedsam_conv2':
        # The number of classes is hard-coded to 10 (CIFAR-10)
        model = Conv2Model(10)
        return model


register_model('fedsam_conv2', call_fedsam_conv2)
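
Note (not part of the diff): as a quick sanity check of the architecture, Conv2Model maps a CIFAR-10-sized batch to num_classes logits; the 64 * 5 * 5 flatten size follows from two 5x5 convolutions and two 2x2 max-poolings applied to 32x32 inputs. A minimal sketch, assuming FederatedScope is installed:

import torch
from federatedscope.contrib.model.fedsam_convnet import Conv2Model

model = Conv2Model(num_classes=10)
x = torch.randn(2, 3, 32, 32)      # a CIFAR-10-sized batch
logits = model(x)                  # 32 -> 28 -> 14 -> 10 -> 5 per spatial dim
assert logits.shape == (2, 10)
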
4 changes: 2 additions & 2 deletions federatedscope/contrib/splitter/example.py
@@ -17,8 +17,8 @@ def __call__(self, dataset, *args, **kwargs):
return data_list


-def call_my_splitter(client_num, **kwargs):
-    if type == 'mysplitter':
+def call_my_splitter(splitter_type, client_num, **kwargs):
+    if splitter_type == 'mysplitter':
splitter = MySplitter(client_num, **kwargs)
return splitter

63 changes: 63 additions & 0 deletions federatedscope/contrib/splitter/fedsam_cifar10_splitter.py
@@ -0,0 +1,63 @@
'''TODO: supplement copyright
'''
import re
import json
import numpy as np

from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter


class FedSAM_CIFAR10_Splitter(BaseSplitter):
"""
    This splitter partitions the dataset according to the pre-computed partition files provided by FedSAM

Args:
client_num: the dataset will be split into ``client_num`` pieces
        alpha (float): Partition hyperparameter in LDA; a smaller alpha \
            generates a more heterogeneous scenario, see \
            ``np.random.dirichlet``
"""
def __init__(self, client_num, alpha=0.5):
self.alpha = alpha
super(FedSAM_CIFAR10_Splitter, self).__init__(client_num)

def __call__(self, dataset, prior=None, **kwargs):
dataset = [ds for ds in dataset]
label = np.array([y for x, y in dataset])
alpha_str = f'{self.alpha:.2f}'
if len(label) == 50000:
filename = \
'data/fedsam_cifar10/data/{}/federated_{}_alpha_{' \
'}.json'.format('train', 'train', alpha_str)
elif len(label) == 10000:
filename = 'data/fedsam_cifar10/data/test/test.json'
with open(filename, 'r') as ips:
content = json.load(ips)
idx_slice = []

        def get_idx(name_list):
            # Recover the original dataset index from file names of the
            # form 'img_<idx>_label...'
            return [
                int(re.findall(r'img_\d+_label', fn)[0][4:-6])
                for fn in name_list
            ]

if len(label) == 50000:
for uid in range(self.client_num):
idx_slice.append(
get_idx(content['user_data'][str(uid)]['x']))
elif len(label) == 10000:
idx_slice.append(get_idx(content['user_data'][str(100)]['x']))
idx_slice = np.array_split(np.array(idx_slice[0]),
self.client_num)
data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
return data_list


def call_fedsam_cifar10_splitter(splitter_type, client_num, **kwargs):
if splitter_type == 'fedsam_cifar10_splitter':
splitter = FedSAM_CIFAR10_Splitter(client_num, **kwargs)
return splitter


register_splitter('fedsam_cifar10_splitter', call_fedsam_cifar10_splitter)
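
Note (not part of the diff): the splitter recovers the original dataset indices by parsing the image file names stored in FedSAM's partition JSON files. Assuming names of the form 'img_<idx>_label...' (the exact suffix is hypothetical here), the slicing in get_idx works as follows:

import re

fn = 'img_123_label_5.png'                     # hypothetical file name
token = re.findall(r'img_\d+_label', fn)[0]    # -> 'img_123_label'
idx = int(token[4:-6])                         # strip 'img_' and '_label' -> 123
assert idx == 123
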
184 changes: 184 additions & 0 deletions federatedscope/contrib/trainer/sam.py
@@ -0,0 +1,184 @@
'''The implementations of ASAM and SAM are borrowed from
https://github.com/debcaldarola/fedsam
Caldarola, D., Caputo, B., & Ciccone, M.
Improving Generalization in Federated Learning by Seeking Flat Minima,
European Conference on Computer Vision (ECCV) 2022.
'''
from collections import defaultdict
import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


class ASAM(object):
def __init__(self, optimizer, model, rho=0.5, eta=0.01):
self.optimizer = optimizer
self.model = model
self.rho = rho
self.eta = eta
self.state = defaultdict(dict)

@torch.no_grad()
def ascent_step(self):
wgrads = []
for n, p in self.model.named_parameters():
if p.grad is None:
continue
t_w = self.state[p].get("eps")
if t_w is None:
t_w = torch.clone(p).detach()
self.state[p]["eps"] = t_w
if 'weight' in n:
t_w[...] = p[...]
t_w.abs_().add_(self.eta)
p.grad.mul_(t_w)
wgrads.append(torch.norm(p.grad, p=2))
wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
for n, p in self.model.named_parameters():
if p.grad is None:
continue
t_w = self.state[p].get("eps")
if 'weight' in n:
p.grad.mul_(t_w)
eps = t_w
eps[...] = p.grad[...]
eps.mul_(self.rho / wgrad_norm)
p.add_(eps)
self.optimizer.zero_grad()

@torch.no_grad()
def descent_step(self):
for n, p in self.model.named_parameters():
if p.grad is None:
continue
p.sub_(self.state[p]["eps"])
self.optimizer.step()
self.optimizer.zero_grad()


class SAM(ASAM):
@torch.no_grad()
def ascent_step(self):
grads = []
for n, p in self.model.named_parameters():
if p.grad is None:
continue
grads.append(torch.norm(p.grad, p=2))
grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16
for n, p in self.model.named_parameters():
if p.grad is None:
continue
eps = self.state[p].get("eps")
if eps is None:
eps = torch.clone(p).detach()
self.state[p]["eps"] = eps
eps[...] = p.grad[...]
eps.mul_(self.rho / grad_norm)
p.add_(eps)
self.optimizer.zero_grad()


class SAMTrainer(BaseTrainer):
def __init__(self, model, data, device, **kwargs):
# NN modules
self.model = model
# FS `ClientData` or your own data
self.data = data
# Device name
self.device = device
# configs
self.kwargs = kwargs
self.config = kwargs['config']
self.optim_config = self.config.train.optimizer
self.sam_config = self.config.trainer.sam

def train(self):
# Criterion & Optimizer
criterion = torch.nn.CrossEntropyLoss().to(self.device)
optimizer = get_optimizer(self.model, **self.optim_config)

# _hook_on_fit_start_init
self.model.to(self.device)
self.model.train()

num_samples, total_loss = self.run_epoch(optimizer, criterion)

# _hook_on_fit_end
return num_samples, self.model.cpu().state_dict(), \
{'loss_total': total_loss, 'avg_loss': total_loss/float(
num_samples)}

def run_epoch(self, optimizer, criterion):
if self.sam_config.adaptive:
minimizer = ASAM(optimizer,
self.model,
rho=self.sam_config.rho,
eta=self.sam_config.eta)
else:
minimizer = SAM(optimizer,
self.model,
rho=self.sam_config.rho,
eta=self.sam_config.eta)
running_loss = 0.0
num_samples = 0
# for inputs, targets in self.trainloader:
for inputs, targets in self.data['train']:
inputs = inputs.to(self.device)
targets = targets.to(self.device)

# Ascent Step
outputs = self.model(inputs)
loss = criterion(outputs, targets)
loss.backward()
minimizer.ascent_step()

# Descent Step
criterion(self.model(inputs), targets).backward()
minimizer.descent_step()

with torch.no_grad():
running_loss += targets.shape[0] * loss.item()

num_samples += targets.shape[0]

return num_samples, running_loss

def evaluate(self, target_data_split_name='test'):
if target_data_split_name != 'test':
return {}

with torch.no_grad():
criterion = torch.nn.CrossEntropyLoss().to(self.device)

self.model.to(self.device)
self.model.eval()
total_loss = num_samples = num_corrects = 0
# _hook_on_batch_start_init
for x, y in self.data[target_data_split_name]:
# _hook_on_batch_forward
x, y = x.to(self.device), y.to(self.device)
pred = self.model(x)
loss = criterion(pred, y)
cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

# _hook_on_batch_end
total_loss += loss.item() * y.shape[0]
num_samples += y.shape[0]
num_corrects += cor.item()

# _hook_on_fit_end
return {
f'{target_data_split_name}_acc': float(num_corrects) /
float(num_samples),
f'{target_data_split_name}_loss': total_loss,
f'{target_data_split_name}_total': num_samples,
f'{target_data_split_name}_avg_loss': total_loss /
float(num_samples)
}

def update(self, model_parameters, strict=False):
self.model.load_state_dict(model_parameters, strict)

def get_model_para(self):
return self.model.cpu().state_dict()
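
Note (not part of the diff): outside of FederatedScope, the ASAM/SAM helpers follow the usual two-pass pattern: a first backward pass computes the gradients used to build the perturbation, ascent_step applies the perturbation to the weights, a second backward pass is taken at the perturbed weights, and descent_step restores the weights and performs the real optimizer update. A minimal sketch with plain PyTorch (names and hyperparameters chosen for illustration):

import torch
from federatedscope.contrib.trainer.sam import SAM

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
minimizer = SAM(optimizer, model, rho=0.05, eta=0.0)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))

# First forward/backward: gradients at the current weights
loss = criterion(model(x), y)
loss.backward()
minimizer.ascent_step()       # perturb the weights along the gradient direction

# Second forward/backward: gradients at the perturbed weights
criterion(model(x), y).backward()
minimizer.descent_step()      # restore the weights and take the optimizer step
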
11 changes: 11 additions & 0 deletions federatedscope/contrib/trainer/sam_trainer.py
@@ -0,0 +1,11 @@
from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer


def call_sam_trainer(trainer_type):
if trainer_type == 'sam_trainer':
from federatedscope.contrib.trainer.sam import SAMTrainer
return SAMTrainer


register_trainer('sam_trainer', call_sam_trainer)
3 changes: 2 additions & 1 deletion federatedscope/core/auxiliaries/data_builder.py
@@ -122,7 +122,8 @@ def get_data(config, client_cfgs=None):
# Load dataset from source files
dataset, modified_config = load_dataset(config)

-    # Perform translator to non-FL dataset
+    # Apply translator to non-FL dataset to transform it into its federated
+    # counterpart
translator = getattr(import_module('federatedscope.core.data'),
DATA_TRANS_MAP[config.data.type.lower()])(
modified_config, client_cfgs)
10 changes: 9 additions & 1 deletion federatedscope/core/auxiliaries/splitter_builder.py
@@ -1,8 +1,16 @@
import logging

import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
from federatedscope.contrib.splitter import *
except ImportError as error:
logger.warning(
f'{error} in `federatedscope.contrib.splitter`, some modules are not '
f'available.')


def get_splitter(config):
"""
@@ -37,7 +45,7 @@ def get_splitter(config):
kwargs = {}

for func in register.splitter_dict.values():
-        splitter = func(client_num, **kwargs)
+        splitter = func(config.data.splitter, client_num, **kwargs)
if splitter is not None:
return splitter
# Delay import
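
Note (not part of the diff): with this change, every registered splitter builder receives the requested splitter type first and is expected to return None when the type does not match, so the loop in get_splitter can fall through to the built-in splitters. An illustrative sketch, assuming FederatedScope is installed:

from federatedscope.contrib.splitter.fedsam_cifar10_splitter import \
    call_fedsam_cifar10_splitter

# Non-matching type: the builder returns None and the loop keeps searching
assert call_fedsam_cifar10_splitter('lda', client_num=100) is None
# Matching type: the builder constructs the splitter
splitter = call_fedsam_cifar10_splitter('fedsam_cifar10_splitter', 100)
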
5 changes: 5 additions & 0 deletions federatedscope/core/configs/cfg_training.py
@@ -10,6 +10,11 @@ def extend_training_cfg(cfg):

cfg.trainer.type = 'general'

cfg.trainer.sam = CN()
cfg.trainer.sam.adaptive = False
cfg.trainer.sam.rho = 1.0
cfg.trainer.sam.eta = .0

# ---------------------------------------------------------------------- #
# Training related options
# ---------------------------------------------------------------------- #
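
Note (not part of the diff): the new cfg.trainer.sam node above is what SAMTrainer reads at run time: adaptive switches between ASAM (True) and SAM (False), rho scales the perturbation radius, and eta is only used by ASAM. An illustrative override wiring together the pieces added in this PR, assuming cfg is the experiment's config node and with values that are not tuned:

# `cfg` is assumed to be FederatedScope's experiment config (a CfgNode)
cfg.model.type = 'fedsam_conv2'                # checked by call_fedsam_conv2
cfg.data.splitter = 'fedsam_cifar10_splitter'  # checked by the new splitter builder
cfg.trainer.type = 'sam_trainer'               # checked by call_sam_trainer
cfg.trainer.sam.adaptive = False               # True selects ASAM, False selects SAM
cfg.trainer.sam.rho = 0.05                     # perturbation radius (default: 1.0)
cfg.trainer.sam.eta = 0.01                     # only read by ASAM (default: 0.0)
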