-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import tensorflow as tf | ||
|
||
# Take a look at some samples from the dataset: each class shows some | ||
def showPic(X_train, y_train): | ||
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] | ||
num_classes = len(classes) | ||
samples_per_class = 7 | ||
for y, cls in enumerate(classes): | ||
idxs = np.flatnonzero(y_train == y) | ||
# Randomly pick some from a category | ||
idxs = np.random.choice(idxs, samples_per_class, replace=False) | ||
for i, idx in enumerate(idxs): | ||
plt_idx = i * num_classes + y + 1 | ||
plt.subplot(samples_per_class, num_classes, plt_idx) | ||
plt.imshow(X_train[idx].astype('uint8')) | ||
plt.axis('off') | ||
if i == 0: | ||
plt.title(cls) | ||
plt.show() | ||
|
||
if __name__ == '__main__': | ||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() | ||
showPic(x_train, y_train) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"model_name" : "resnet18", | ||
"no_models" : 10, | ||
"type" : "cifar", | ||
"global_epochs" : 20, | ||
"local_epochs" : 3, | ||
"k" : 9, | ||
"batch_size" : 32, | ||
"lr" : 0.001, | ||
"momentum" : 0.0001, | ||
"lambda" : 0.1 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import torch | ||
from torchvision import datasets, transforms | ||
|
||
def get_dataset(dir, name): | ||
if name == 'cifar': | ||
# Set two conversion formats | ||
# transforms.Compose is a combination of multiple transforms (a list of transforms) | ||
transform_train = transforms.Compose([ | ||
# transforms.RandomCrop: 切割中心点的位置随机选取 | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
# transforms.Normalize: 给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化 | ||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), | ||
]) | ||
|
||
transform_test = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), | ||
]) | ||
|
||
train_dataset = datasets.CIFAR10(dir, train=True, download=True, | ||
transform=transform_train) | ||
eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test) | ||
|
||
return train_dataset, eval_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import models, torch, copy | ||
|
||
|
||
class Client(object): | ||
|
||
def __init__(self, conf, model, train_dataset, id=-1): | ||
|
||
self.conf = conf | ||
# Client local model (usually transmitted by the server) | ||
self.local_model = models.get_model(self.conf["model_name"]) | ||
self.client_id = id | ||
self.train_dataset = train_dataset | ||
# Split training set by ID | ||
all_range = list(range(len(self.train_dataset))) | ||
data_len = int(len(self.train_dataset) / self.conf['no_models']) | ||
train_indices = all_range[id * data_len: (id + 1) * data_len] | ||
|
||
self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"], | ||
sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||
train_indices)) | ||
|
||
# model local training function | ||
def local_train(self, model): | ||
# Overall process: pull the model of the server and get it through training on some local datasets | ||
for name, param in model.state_dict().items(): | ||
# The client first overwrites the local model with the global model issued by the server | ||
self.local_model.state_dict()[name].copy_(param.clone()) | ||
|
||
# print(id(model)) | ||
# Define an optimization function for local model training | ||
optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], | ||
momentum=self.conf['momentum']) | ||
# print(id(self.local_model)) | ||
self.local_model.train() | ||
for e in range(self.conf["local_epochs"]): | ||
|
||
for batch_id, batch in enumerate(self.train_loader): | ||
data, target = batch | ||
|
||
if torch.cuda.is_available(): | ||
data = data.cuda() | ||
target = target.cuda() | ||
|
||
optimizer.zero_grad() | ||
output = self.local_model(data) | ||
loss = torch.nn.functional.cross_entropy(output, target) | ||
loss.backward() | ||
|
||
optimizer.step() | ||
print("Epoch %d done." % e) | ||
diff = dict() | ||
for name, data in self.local_model.state_dict().items(): | ||
diff[name] = (data - model.state_dict()[name]) | ||
# print(diff[name]) | ||
print("Client %d local train done" % self.client_id) | ||
|
||
return diff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# 联邦学习 | ||
# | ||
|
||
import json | ||
import torch, random | ||
from fl_server import * | ||
from fl_client import * | ||
import models, datasets | ||
import time | ||
|
||
# The basic process of horizontal federated learning implemented: | ||
# | ||
# 1. The server generates the initialization model according to the configuration, | ||
# and the client cuts the dataset horizontally without overlapping in sequence. | ||
# 2. The server sends the global model to the client. | ||
# 3. The client receives the global model (from the server) and returns the local | ||
# parameter difference to the server through local iterations. | ||
# 4. The server aggregates the difference between each client to update the model, | ||
# and then evaluates the current model performance | ||
# If the performance is not up to standard, repeat the process of 2, otherwise end. | ||
# print(__name__) | ||
if __name__ == '__main__': | ||
|
||
# load configuration file | ||
with open("conf.json", 'r') as f: | ||
conf = json.load(f) | ||
# Load dataset: train datasets, eval datasets | ||
train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"]) | ||
# Start the server | ||
server = Server(conf, eval_datasets) | ||
# client List | ||
clients = [] | ||
# Add N clients to the list according to the conf configuration file | ||
for c in range(conf["no_models"]): | ||
clients.append(Client(conf, server.global_model, train_datasets, c)) | ||
# 开始时间 | ||
start_time = time.time() | ||
print("begin time:", start_time) | ||
|
||
# for the convenience of implementation, the implementation does not use network communication | ||
# to simulate the communication between the client and the server, but simulates it locally in a circular manner. | ||
print("begin global model training \n") | ||
# Global model training | ||
for e in range(conf["global_epochs"]): | ||
print("Global Epoch %d" % e) | ||
# Each training is to randomly sample k from the clients list for this round of training | ||
candidates = random.sample(clients, conf["k"]) | ||
# weight accumulation | ||
weight_accumulator = {} | ||
# Initialize empty model parameter weight_accumulator | ||
for name, params in server.global_model.state_dict().items(): | ||
# Generate a 0 matrix of the same size as the parameter matrix | ||
weight_accumulator[name] = torch.zeros_like(params) | ||
|
||
# Traverse clients, each client trains the model locally | ||
for c in candidates: | ||
diff = c.local_train(server.global_model) | ||
# print("client:", diff ) | ||
# Update the overall weight according to the client's parameter difference dictionary | ||
for name, params in server.global_model.state_dict().items(): | ||
weight_accumulator[name].add_(diff[name]) | ||
|
||
# model parameter aggregation | ||
server.model_aggregate(weight_accumulator) | ||
# model evaluation | ||
acc, loss = server.model_eval() | ||
|
||
print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss)) | ||
|
||
end_time = time.time() | ||
print("end time:", end_time) | ||
|
||
end_time_calc = round(time.time() - start_time, 4) | ||
|
||
print('Execution time: {} seconds'.format(end_time_calc)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# 联邦学习 | ||
# | ||
|
||
import json | ||
import torch, random | ||
from fl_server import * | ||
from fl_client import * | ||
import models, datasets | ||
import time | ||
|
||
# The basic process of horizontal federated learning implemented: | ||
# | ||
# 1. The server generates the initialization model according to the configuration, | ||
# and the client cuts the dataset horizontally without overlapping in sequence. | ||
# 2. The server sends the global model to the client. | ||
# 3. The client receives the global model (from the server) and returns the local | ||
# parameter difference to the server through local iterations. | ||
# 4. The server aggregates the difference between each client to update the model, | ||
# and then evaluates the current model performance | ||
# If the performance is not up to standard, repeat the process of 2, otherwise end. | ||
# print(__name__) | ||
if __name__ == '__main__': | ||
|
||
# load configuration file | ||
with open("conf.json", 'r') as f: | ||
conf = json.load(f) | ||
# Load dataset: train datasets, eval datasets | ||
train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"]) | ||
# Start the server | ||
server = Server(conf, eval_datasets) | ||
# client List | ||
clients = [] | ||
# Add N clients to the list according to the conf configuration file | ||
for c in range(conf["no_models"]): | ||
clients.append(Client(conf, server.global_model, train_datasets, c)) | ||
# 开始时间 | ||
start_time = time.time() | ||
print("begin time:", start_time) | ||
|
||
# for the convenience of implementation, the implementation does not use network communication | ||
# to simulate the communication between the client and the server, but simulates it locally in a circular manner. | ||
print("begin global model training \n") | ||
# Global model training | ||
for e in range(conf["global_epochs"]): | ||
print("Global Epoch %d" % e) | ||
# Each training is to randomly sample k from the clients list for this round of training | ||
candidates = random.sample(clients, conf["k"]) | ||
# weight accumulation | ||
weight_accumulator = {} | ||
# Initialize empty model parameter weight_accumulator | ||
for name, params in server.global_model.state_dict().items(): | ||
# Generate a 0 matrix of the same size as the parameter matrix | ||
weight_accumulator[name] = torch.zeros_like(params) | ||
|
||
# Traverse clients, each client trains the model locally | ||
for c in candidates: | ||
diff = c.local_train(server.global_model) | ||
# print("client:", diff ) | ||
# Update the overall weight according to the client's parameter difference dictionary | ||
for name, params in server.global_model.state_dict().items(): | ||
weight_accumulator[name].add_(diff[name]) | ||
|
||
# model parameter aggregation | ||
server.model_aggregate(weight_accumulator) | ||
# model evaluation | ||
acc, loss = server.model_eval() | ||
|
||
print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss)) | ||
|
||
end_time = time.time() | ||
print("end time:", end_time) | ||
|
||
end_time_calc = round(time.time() - start_time, 4) | ||
|
||
print('Execution time: {} seconds'.format(end_time_calc)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import models, torch | ||
|
||
|
||
class Server(object): | ||
# Define the constructor to complete the initialization of configuration parameters | ||
def __init__(self, conf, eval_dataset): | ||
self.conf = conf | ||
self.global_model = models.get_model(self.conf["model_name"]) | ||
self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True) | ||
|
||
# global aggregation model | ||
# weight_accumulator stores the upload parameter change value/difference value of each client | ||
def model_aggregate(self, weight_accumulator): | ||
# Traverse the server's global model | ||
for name, data in self.global_model.state_dict().items(): | ||
# update each layer multiplied by the learning rate 更新每一层乘上学习率 | ||
update_per_layer = weight_accumulator[name] * self.conf["lambda"] | ||
# cumulative sum | ||
if data.type() != update_per_layer.type(): | ||
# Because the type of update_per_layer is floatTensor, it will be converted to | ||
# LongTensor of the model (with a certain precision loss) | ||
data.add_(update_per_layer.to(torch.int64)) | ||
else: | ||
data.add_(update_per_layer) | ||
|
||
# evaluate function | ||
def model_eval(self): | ||
# Enable model evaluation mode (without modifying parameters) | ||
self.global_model.eval() | ||
total_loss = 0.0 | ||
correct = 0 | ||
dataset_size = 0 | ||
# Iterate over the evaluation data set | ||
for batch_id, batch in enumerate(self.eval_loader): | ||
data, target = batch | ||
# 获取所有的样本总量大小 | ||
dataset_size += data.size()[0] | ||
if torch.cuda.is_available(): | ||
data = data.cuda() | ||
target = target.cuda() | ||
output = self.global_model(data) | ||
# 聚合所有的损失 sum up batch loss | ||
# cross_entropy 交叉熵函数计算损失 | ||
total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item() | ||
# get the index of the max log-probability | ||
pred = output.data.max(1)[1] | ||
# 统计预测结果与真实标签target的匹配总个数 | ||
correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() | ||
# Calculate accuracy | ||
acc = 100.0 * (float(correct) / float(dataset_size)) | ||
print("server acc", acc) | ||
# Calculate the loss value | ||
total_l = total_loss / dataset_size | ||
|
||
return acc, total_l |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
import torch | ||
from torchvision import models | ||
|
||
def get_model(name="vgg16", pretrained=True): | ||
if name == "resnet18": | ||
model = models.resnet18(pretrained=pretrained) | ||
elif name == "resnet50": | ||
model = models.resnet50(pretrained=pretrained) | ||
elif name == "densenet121": | ||
model = models.densenet121(pretrained=pretrained) | ||
elif name == "alexnet": | ||
model = models.alexnet(pretrained=pretrained) | ||
elif name == "vgg16": | ||
model = models.vgg16(pretrained=pretrained) | ||
elif name == "vgg19": | ||
model = models.vgg19(pretrained=pretrained) | ||
elif name == "inception_v3": | ||
model = models.inception_v3(pretrained=pretrained) | ||
elif name == "googlenet": | ||
model = models.googlenet(pretrained=pretrained) | ||
|
||
if torch.cuda.is_available(): | ||
return model.cuda() | ||
else: | ||
return model |