Commit

code initialization

liyu98 committed May 7, 2022
1 parent 41005f8 commit 468a07b
Showing 8 changed files with 351 additions and 0 deletions.
25 changes: 25 additions & 0 deletions cifar10.py
@@ -0,0 +1,25 @@
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Show a few random samples from each of the ten CIFAR-10 classes
def showPic(X_train, y_train):
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(classes)
    samples_per_class = 7
    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(y_train == y)
        # Randomly pick a few samples from this class
        idxs = np.random.choice(idxs, samples_per_class, replace=False)
        for i, idx in enumerate(idxs):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    showPic(x_train, y_train)
12 changes: 12 additions & 0 deletions conf.json
@@ -0,0 +1,12 @@
{
"model_name" : "resnet18",
"no_models" : 10,
"type" : "cifar",
"global_epochs" : 20,
"local_epochs" : 3,
"k" : 9,
"batch_size" : 32,
"lr" : 0.001,
"momentum" : 0.0001,
"lambda" : 0.1
}
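For orientation: no_models is the number of simulated clients, k is how many of them are sampled per round, global_epochs and local_epochs control the outer and inner training loops, and lambda scales the aggregated parameter differences on the server. A minimal sketch of reading the file, assuming it sits in the working directory as the scripts below expect:

import json

with open("conf.json", "r") as f:
    conf = json.load(f)
print(conf["model_name"], conf["no_models"], conf["k"])  # resnet18 10 9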
26 changes: 26 additions & 0 deletions datasets.py
@@ -0,0 +1,26 @@
import torch
from torchvision import datasets, transforms

def get_dataset(dir, name):
    if name == 'cifar':
        # Define two transform pipelines, one for training and one for evaluation
        # transforms.Compose chains multiple transforms together (a list of transforms)
        transform_train = transforms.Compose([
            # transforms.RandomCrop: crop a 32x32 window at a random position (after padding)
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # transforms.Normalize: normalize the tensor given per-channel means (R,G,B)
            # and standard deviations (R,G,B)
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10(dir, train=True, download=True,
                                         transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)

    return train_dataset, eval_dataset
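A quick usage sketch (assuming ./data/ is writable; the first call downloads CIFAR-10):

from datasets import get_dataset

train_ds, eval_ds = get_dataset("./data/", "cifar")
print(len(train_ds), len(eval_ds))  # 50000 10000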
57 changes: 57 additions & 0 deletions fl_client.py
@@ -0,0 +1,57 @@
import models, torch, copy


class Client(object):

    def __init__(self, conf, model, train_dataset, id=-1):

        self.conf = conf
        # The client's local model (its weights are overwritten by the global
        # model the server sends at the start of each round)
        self.local_model = models.get_model(self.conf["model_name"])
        self.client_id = id
        self.train_dataset = train_dataset
        # Partition the training set by client ID: each client takes a disjoint,
        # contiguous slice of the data
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_models'])
        train_indices = all_range[id * data_len: (id + 1) * data_len]

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"],
                                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                                            train_indices))

    # Local training: pull the server's global model, train it on the local data
    # partition, and return the per-layer parameter differences
    def local_train(self, model):
        # First overwrite the local model with the global model sent by the server
        for name, param in model.state_dict().items():
            self.local_model.state_dict()[name].copy_(param.clone())

        # Optimizer for local model training
        optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
                                    momentum=self.conf['momentum'])
        self.local_model.train()
        for e in range(self.conf["local_epochs"]):

            for batch_id, batch in enumerate(self.train_loader):
                data, target = batch

                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()

                optimizer.zero_grad()
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()

                optimizer.step()
            print("Epoch %d done." % e)
        # Per-layer difference between the locally trained model and the global model
        diff = dict()
        for name, data in self.local_model.state_dict().items():
            diff[name] = (data - model.state_dict()[name])

        print("Client %d local train done" % self.client_id)

        return diff
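The invariant behind the returned diff: for every layer, local_state[name] == global_state[name] + diff[name]. A minimal self-contained sketch of that relationship (toy model and fake "local training", not the repo's loop):

import copy
import torch

# Toy stand-in for the global model and a client's copy of it
global_model = torch.nn.Linear(4, 2)
local_model = copy.deepcopy(global_model)

with torch.no_grad():
    for p in local_model.parameters():
        p.add_(0.01)  # stand-in for a few steps of local SGD

# The per-layer diff, exactly as local_train() computes it
diff = {name: data - global_model.state_dict()[name]
        for name, data in local_model.state_dict().items()}

# Global weights plus the diff reproduce the locally trained weights
for name, d in diff.items():
    assert torch.allclose(global_model.state_dict()[name] + d,
                          local_model.state_dict()[name])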
75 changes: 75 additions & 0 deletions fl_integration.py
@@ -0,0 +1,75 @@
# Federated learning
#

import json
import torch, random
from fl_server import *
from fl_client import *
import models, datasets
import time

# The basic process of the horizontal federated learning implemented here:
#
# 1. The server builds the initial global model from the configuration, and the
#    dataset is split horizontally into non-overlapping, contiguous client shards.
# 2. The server sends the global model to the clients.
# 3. Each client receives the global model, trains locally for several epochs,
#    and returns its parameter differences to the server.
# 4. The server aggregates the clients' differences to update the global model,
#    then evaluates the current model. If performance is not yet acceptable,
#    repeat from step 2; otherwise stop.
if __name__ == '__main__':

    # Load the configuration file
    with open("conf.json", 'r') as f:
        conf = json.load(f)
    # Load the datasets: training set and evaluation set
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
    # Start the server
    server = Server(conf, eval_datasets)
    # Client list
    clients = []
    # Add N clients to the list, as set in the configuration file
    for c in range(conf["no_models"]):
        clients.append(Client(conf, server.global_model, train_datasets, c))
    # Start time
    start_time = time.time()
    print("begin time:", start_time)

    # For simplicity, client-server communication is not simulated over the
    # network; the exchange is simulated locally in a loop instead.
    print("begin global model training \n")
    # Global model training
    for e in range(conf["global_epochs"]):
        print("Global Epoch %d" % e)
        # Each round randomly samples k clients from the client list
        candidates = random.sample(clients, conf["k"])
        # Weight accumulator: collects the parameter differences of this round
        weight_accumulator = {}
        # Initialize the accumulator with zero tensors
        for name, params in server.global_model.state_dict().items():
            # A zero tensor of the same shape as the parameter tensor
            weight_accumulator[name] = torch.zeros_like(params)

        # Each sampled client trains the model locally
        for c in candidates:
            diff = c.local_train(server.global_model)
            # Accumulate the client's per-layer parameter differences
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])

        # Aggregate the accumulated differences into the global model
        server.model_aggregate(weight_accumulator)
        # Evaluate the global model
        acc, loss = server.model_eval()

        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))

    end_time = time.time()
    print("end time:", end_time)

    end_time_calc = round(time.time() - start_time, 4)

    print('Execution time: {} seconds'.format(end_time_calc))
75 changes: 75 additions & 0 deletions fl_main.py
@@ -0,0 +1,75 @@
# Federated learning
#

import json
import torch, random
from fl_server import *
from fl_client import *
import models, datasets
import time

# The basic process of the horizontal federated learning implemented here:
#
# 1. The server builds the initial global model from the configuration, and the
#    dataset is split horizontally into non-overlapping, contiguous client shards.
# 2. The server sends the global model to the clients.
# 3. Each client receives the global model, trains locally for several epochs,
#    and returns its parameter differences to the server.
# 4. The server aggregates the clients' differences to update the global model,
#    then evaluates the current model. If performance is not yet acceptable,
#    repeat from step 2; otherwise stop.
if __name__ == '__main__':

    # Load the configuration file
    with open("conf.json", 'r') as f:
        conf = json.load(f)
    # Load the datasets: training set and evaluation set
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
    # Start the server
    server = Server(conf, eval_datasets)
    # Client list
    clients = []
    # Add N clients to the list, as set in the configuration file
    for c in range(conf["no_models"]):
        clients.append(Client(conf, server.global_model, train_datasets, c))
    # Start time
    start_time = time.time()
    print("begin time:", start_time)

    # For simplicity, client-server communication is not simulated over the
    # network; the exchange is simulated locally in a loop instead.
    print("begin global model training \n")
    # Global model training
    for e in range(conf["global_epochs"]):
        print("Global Epoch %d" % e)
        # Each round randomly samples k clients from the client list
        candidates = random.sample(clients, conf["k"])
        # Weight accumulator: collects the parameter differences of this round
        weight_accumulator = {}
        # Initialize the accumulator with zero tensors
        for name, params in server.global_model.state_dict().items():
            # A zero tensor of the same shape as the parameter tensor
            weight_accumulator[name] = torch.zeros_like(params)

        # Each sampled client trains the model locally
        for c in candidates:
            diff = c.local_train(server.global_model)
            # Accumulate the client's per-layer parameter differences
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])

        # Aggregate the accumulated differences into the global model
        server.model_aggregate(weight_accumulator)
        # Evaluate the global model
        acc, loss = server.model_eval()

        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))

    end_time = time.time()
    print("end time:", end_time)

    end_time_calc = round(time.time() - start_time, 4)

    print('Execution time: {} seconds'.format(end_time_calc))
55 changes: 55 additions & 0 deletions fl_server.py
@@ -0,0 +1,55 @@
import models, torch


class Server(object):
    # Constructor: initialize the configuration, the global model, and the evaluation loader
    def __init__(self, conf, eval_dataset):
        self.conf = conf
        self.global_model = models.get_model(self.conf["model_name"])
        self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)

    # Global model aggregation
    # weight_accumulator holds the summed parameter differences uploaded by the clients
    def model_aggregate(self, weight_accumulator):
        # Traverse the layers of the server's global model
        for name, data in self.global_model.state_dict().items():
            # Scale each layer's accumulated difference by the aggregation rate lambda
            update_per_layer = weight_accumulator[name] * self.conf["lambda"]
            # Accumulate into the global parameters
            if data.type() != update_per_layer.type():
                # update_per_layer is a FloatTensor; integer tensors in the state
                # dict (e.g. BatchNorm's num_batches_tracked, a LongTensor) require
                # casting the update to int64 first, with some loss of precision
                data.add_(update_per_layer.to(torch.int64))
            else:
                data.add_(update_per_layer)

    # Evaluation function
    def model_eval(self):
        # Switch the model to evaluation mode (parameters are not modified)
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        # Iterate over the evaluation dataset
        for batch_id, batch in enumerate(self.eval_loader):
            data, target = batch
            # Count the total number of evaluated samples
            dataset_size += data.size()[0]
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            output = self.global_model(data)
            # Sum up the batch loss with the cross-entropy function
            total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            # Get the index of the max log-probability as the prediction
            pred = output.data.max(1)[1]
            # Count predictions that match the ground-truth target
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        # Accuracy in percent
        acc = 100.0 * (float(correct) / float(dataset_size))
        print("server acc", acc)
        # Average loss over the evaluation set
        total_l = total_loss / dataset_size

        return acc, total_l
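Taken together with fl_client.py, model_aggregate applies, per layer, the update w_new = w_old + lambda * sum over sampled clients of (w_local - w_old). A toy illustration with plain tensors (assumed values, mirroring local_train() and model_aggregate()):

import torch

global_w = torch.zeros(3)                     # one "layer" of the global model
client_ws = [torch.tensor([1.0, 0.0, 0.0]),   # locally trained weights of two clients
             torch.tensor([0.0, 2.0, 0.0])]

accumulator = torch.zeros_like(global_w)
for w in client_ws:
    accumulator += (w - global_w)             # the diff each client would return

lam = 0.1                                     # conf["lambda"]
global_w += lam * accumulator                 # the server-side update
print(global_w)                               # tensor([0.1000, 0.2000, 0.0000])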
26 changes: 26 additions & 0 deletions models.py
@@ -0,0 +1,26 @@

import torch
from torchvision import models

def get_model(name="vgg16", pretrained=True):
    if name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
    elif name == "resnet50":
        model = models.resnet50(pretrained=pretrained)
    elif name == "densenet121":
        model = models.densenet121(pretrained=pretrained)
    elif name == "alexnet":
        model = models.alexnet(pretrained=pretrained)
    elif name == "vgg16":
        model = models.vgg16(pretrained=pretrained)
    elif name == "vgg19":
        model = models.vgg19(pretrained=pretrained)
    elif name == "inception_v3":
        model = models.inception_v3(pretrained=pretrained)
    elif name == "googlenet":
        model = models.googlenet(pretrained=pretrained)
    else:
        raise ValueError("Unsupported model name: %s" % name)

    if torch.cuda.is_available():
        return model.cuda()
    else:
        return model
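Note that these torchvision models keep their ImageNet classification heads (1000 outputs). CIFAR-10 labels 0-9 still index valid logits, so training runs as-is, though swapping in a 10-way final layer would be the more conventional choice. A usage sketch:

from models import get_model

model = get_model("resnet18", pretrained=True)
print(sum(p.numel() for p in model.parameters()))  # roughly 11.7M parameters for resnet18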
