Merge pull request FedML-AI#1 from tremblerz/split_learning_distributed
Distributed version for Split Learning
Chaoyang He authored Jul 27, 2020
2 parents a40c8db + bfda5ea commit 976252e
Showing 5 changed files with 365 additions and 0 deletions.
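In this design, MPI rank 0 hosts the server half of the model while ranks 1 through N host the client half; clients train one at a time, handing a "semaphore" token around a ring (see client.py below). The experiment script is presumably launched under MPI with something like `mpirun -np <client_number + 1> python main_split_nn.py` (an assumed invocation; the launch command is not shown in this diff).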
34 changes: 34 additions & 0 deletions fedml_api/distributed/split_nn/SplitNNAPI.py
@@ -0,0 +1,34 @@
from mpi4py import MPI

from fedml_api.distributed.split_nn.server import SplitNN_server
from fedml_api.distributed.split_nn.client import SplitNN_client

def SplitNN_init():
    comm = MPI.COMM_WORLD
    process_id = comm.Get_rank()
    worker_number = comm.Get_size()
    return comm, process_id, worker_number

def SplitNN_distributed(process_id, worker_number, device, comm, client_model,
                        server_model, train_data_num, train_data_global, test_data_global,
                        local_data_num, train_data_local, test_data_local, args):
    if process_id == 0:
        init_server(comm, server_model, worker_number, device)
    else:
        server_rank = 0
        init_client(comm, client_model, worker_number, train_data_local, test_data_local,
                    process_id, server_rank, args.epochs, device)

def init_server(comm, server_model, worker_number, device):
    arg_dict = {"comm": comm, "model": server_model, "max_rank": worker_number - 1,
                "device": device}
    server = SplitNN_server(arg_dict)
    server.run()

def init_client(comm, client_model, worker_number, train_data_local, test_data_local,
                process_id, server_rank, epochs, device):
    arg_dict = {"comm": comm, "trainloader": train_data_local, "testloader": test_data_local,
                "model": client_model, "rank": process_id, "server_rank": server_rank,
                "max_rank": worker_number - 1, "epochs": epochs, "device": device}
    client = SplitNN_client(arg_dict)
    client.run()
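This API builds on mpi4py's lowercase send/recv, which pickle arbitrary Python objects, including torch tensors. Below is a standalone sketch of that primitive (not part of this commit; the filename and invocation are assumptions):

# sketch_send_recv.py -- hypothetical filename
# Run (assumed): mpirun -np 2 python sketch_send_recv.py
from mpi4py import MPI

import torch

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    activations = torch.randn(4, 8)           # stand-in for a cut-layer tensor
    labels = torch.tensor([0, 1, 2, 3])
    comm.send([activations, labels], dest=1)  # pickled and shipped as one object
else:
    activations, labels = comm.recv(source=0)
    print(rank, activations.shape, labels)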
75 changes: 75 additions & 0 deletions fedml_api/distributed/split_nn/client.py
@@ -0,0 +1,75 @@
import logging

import torch
import torch.optim as optim


class SplitNN_client():
    def __init__(self, args):
        self.comm = args["comm"]
        self.model = args["model"]
        self.trainloader = args["trainloader"]
        self.testloader = args["testloader"]
        self.rank = args["rank"]
        self.MAX_RANK = args["max_rank"]
        # clients 1..MAX_RANK form a ring; rank 0 is the server
        self.node_left = self.MAX_RANK if self.rank == 1 else self.rank - 1
        self.node_right = 1 if self.rank == self.MAX_RANK else self.rank + 1
        self.epoch_count = 0
        self.MAX_EPOCH_PER_NODE = args["epochs"]
        self.SERVER_RANK = args["server_rank"]
        self.device = args["device"]
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
                                   weight_decay=5e-4)

    def run(self):
        if self.rank == self.MAX_RANK:
            logging.info("sending semaphore from {} to {}".format(self.rank,
                                                                  self.node_right))
            self.comm.send("semaphore", dest=self.node_right)

        while True:
            signal = self.comm.recv(source=self.node_left)

            if signal == "semaphore":
                logging.info("Starting training at node {}".format(self.rank))

            for batch_idx, (inputs, labels) in enumerate(self.trainloader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()

                # forward through the client half, then hand the activations to the server
                intermed_tensor = self.model(inputs)
                self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)
                grads = self.comm.recv(source=self.SERVER_RANK)

                # resume backpropagation from the cut layer with the server's gradients
                intermed_tensor.backward(grads)
                self.optimizer.step()

            logging.info("Epoch over at node {}".format(self.rank))
            del intermed_tensor, grads, inputs, labels
            torch.cuda.empty_cache()

            # Validation loss
            self.comm.send("validation", dest=self.SERVER_RANK)
            for batch_idx, (inputs, labels) in enumerate(self.testloader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                intermed_tensor = self.model(inputs)
                self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)

            del intermed_tensor, inputs, labels
            torch.cuda.empty_cache()

            self.epoch_count += 1
            # pass the token so the next client can run its epoch
            self.comm.send("semaphore", dest=self.node_right)
            # self.comm.send(model.state_dict(), dest=node_right)
            if self.epoch_count == self.MAX_EPOCH_PER_NODE:
                if self.rank == self.MAX_RANK:
                    # last client of the final epoch: tell the server to shut down
                    self.comm.send("over", dest=self.SERVER_RANK)
                else:
                    self.comm.send("done", dest=self.SERVER_RANK)
                break
            self.comm.send("done", dest=self.SERVER_RANK)

25 changes: 25 additions & 0 deletions fedml_api/distributed/split_nn/message_define.py
@@ -0,0 +1,25 @@
class MyMessage(object):
    """
    message type definition
    """
    # server to client
    MSG_TYPE_S2C_INIT_CONFIG = 1
    MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2

    # client to server
    MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3
    MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4

    MSG_ARG_KEY_TYPE = "msg_type"
    MSG_ARG_KEY_SENDER = "sender"
    MSG_ARG_KEY_RECEIVER = "receiver"

    """
    message payload keywords definition
    """
    MSG_ARG_KEY_NUM_SAMPLES = "num_samples"
    MSG_ARG_KEY_MODEL_PARAMS = "model_params"
    MSG_ARG_KEY_LOCAL_TRAINING_ACC = "local_training_acc"
    MSG_ARG_KEY_LOCAL_TRAINING_LOSS = "local_training_loss"
    MSG_ARG_KEY_LOCAL_TEST_ACC = "local_test_acc"
    MSG_ARG_KEY_LOCAL_TEST_LOSS = "local_test_loss"
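These constants are not yet referenced by the client and server above, which exchange raw pickled objects. For illustration only, a hypothetical payload keyed by them might look like this (values invented):

from fedml_api.distributed.split_nn.message_define import MyMessage

# hypothetical stats payload for illustration
msg = {
    MyMessage.MSG_ARG_KEY_TYPE: MyMessage.MSG_TYPE_C2S_SEND_STATS_TO_SERVER,
    MyMessage.MSG_ARG_KEY_SENDER: 1,
    MyMessage.MSG_ARG_KEY_RECEIVER: 0,
    MyMessage.MSG_ARG_KEY_LOCAL_TEST_ACC: 0.87,
    MyMessage.MSG_ARG_KEY_LOCAL_TEST_LOSS: 0.41,
}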
88 changes: 88 additions & 0 deletions fedml_api/distributed/split_nn/server.py
@@ -0,0 +1,88 @@
import logging
import datetime

import torch.nn as nn
import torch.optim as optim


class SplitNN_server():
    def __init__(self, args):
        self.comm = args["comm"]
        self.model = args["model"]
        self.MAX_RANK = args["max_rank"]
        self.active_node = 1
        self.epoch = 0
        self.batch_idx = 0
        self.step = 0
        self.log_step = 50
        self.phase = "train"
        self.val_loss = 0
        self.total = 0
        self.correct = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
                                   weight_decay=5e-4)
        self.criterion = nn.CrossEntropyLoss()

    def run(self):
        while True:
            message = self.comm.recv(source=self.active_node)
            if message == "done":
                # the active client finished its epoch; log the averaged
                # validation loss (a rough estimate) and rotate to the next client
                self.val_loss /= self.step
                acc = self.correct / self.total
                logging.info("phase={} acc={} loss={} epoch={} and step={}"
                             .format(self.phase, acc, self.val_loss, self.epoch, self.step))

                self.epoch += 1
                self.active_node = (self.active_node % self.MAX_RANK) + 1
                self.phase = "train"
                self.total = 0
                self.correct = 0
                self.val_loss = 0
                self.step = 0
                self.batch_idx = 0
                logging.info("current active client is {}".format(self.active_node))
            elif message == "over":
                logging.info("training over")
                break
            elif message == "validation":
                self.phase = "validation"
                self.step = 0
                self.total = 0
                self.correct = 0
            else:
                # message is a [intermediate activations, labels] pair
                if self.phase == "train":
                    logging.debug("Server-Receive: client={}, index={}, time={}"
                                  .format(self.active_node, self.batch_idx,
                                          datetime.datetime.now()))
                    self.optimizer.zero_grad()
                input_tensor, labels = message
                input_tensor.retain_grad()
                logits = self.model(input_tensor)
                _, predictions = logits.max(1)

                loss = self.criterion(logits, labels)
                self.total += labels.size(0)
                self.correct += predictions.eq(labels).sum().item()

                if self.phase == "train":
                    loss.backward()
                    self.optimizer.step()
                    # return the gradient at the cut layer to the client
                    self.comm.send(input_tensor.grad, dest=self.active_node)
                    self.batch_idx += 1

                self.step += 1
                if self.step % self.log_step == 0 and self.phase == "train":
                    acc = self.correct / self.total
                    logging.info("phase={} acc={} loss={} epoch={} and step={}"
                                 .format("train", acc, loss.item(), self.epoch, self.step))
                if self.phase == "validation":
                    self.val_loss += loss.item()
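Together, client and server implement a backward pass split across two processes: the server backpropagates down to the cut layer and ships input_tensor.grad back, and the client resumes with intermed_tensor.backward(grads). A single-process sketch of that mechanism (no MPI; models and shapes are illustrative placeholders):

import torch
import torch.nn as nn

client_model = nn.Linear(10, 6)  # bottom half (client side)
server_model = nn.Linear(6, 3)   # top half (server side)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(4, 10)
labels = torch.tensor([0, 1, 2, 0])

# client forward
intermed_tensor = client_model(inputs)

# "send" to the server: after pickling/unpickling the tensor arrives cut off
# from the client's graph, which detach().requires_grad_() simulates here
input_tensor = intermed_tensor.detach().requires_grad_(True)

# server forward/backward
logits = server_model(input_tensor)
loss = criterion(logits, labels)
loss.backward()            # populates input_tensor.grad at the cut layer
grads = input_tensor.grad  # this is what the server "sends" back

# client resumes backpropagation from the cut layer
intermed_tensor.backward(grads)
print(client_model.weight.grad.shape)  # client parameters now have gradients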

143 changes: 143 additions & 0 deletions fedml_experiments/distributed/split_nn/main_split_nn.py
@@ -0,0 +1,143 @@
import setproctitle
import argparse
import logging
import sys
import os
import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_distributed_cifar100
from fedml_api.data_preprocessing.cinic10.data_loader import load_partition_data_distributed_cinic10
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_distributed_cifar10

from fedml_api.distributed.split_nn.SplitNNAPI import SplitNN_init, SplitNN_distributed
from fedml_api.model.deep_neural_networks.mobilenet import mobilenet
from fedml_api.model.deep_neural_networks.resnet import resnet56

def add_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    # Training settings
    parser.add_argument('--model', type=str, default='resnet56', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    parser.add_argument('--client_number', type=int, default=16, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    parser.add_argument('--wd', help='weight decay parameter', type=float, default=0.001)

    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')

    parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication we should use')

    parser.add_argument('--frequency_of_the_test', type=int, default=1,
                        help='the frequency of the test algorithms')

    parser.add_argument('--gpu_server_num', type=int, default=1,
                        help='gpu_server_num')

    parser.add_argument('--gpu_num_per_server', type=int, default=4,
                        help='gpu_num_per_server')
    args = parser.parse_args()
    return args


def init_training_device(process_ID, fl_worker_num, gpu_num_per_machine):
    # initialize the mapping from process ID to GPU ID: <process ID, GPU ID>
    if process_ID == 0:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return device
    process_gpu_dict = dict()
    for client_index in range(fl_worker_num):
        gpu_index = client_index % gpu_num_per_machine
        process_gpu_dict[client_index] = gpu_index

    logging.info(process_gpu_dict)
    device = torch.device("cuda:" + str(process_gpu_dict[process_ID - 1]) if torch.cuda.is_available() else "cpu")
    logging.info(device)
    return device


if __name__ == "__main__":
    comm, process_id, worker_number = SplitNN_init()

    parser = argparse.ArgumentParser()
    args = add_args(parser)

    device = init_training_device(process_id, worker_number - 1, args.gpu_num_per_server)

    str_process_name = "SplitNN (distributed):" + str(process_id)
    setproctitle.setproctitle(str_process_name)

    logging.basicConfig(level=logging.INFO,
                        format=str(
                            process_id) + ' - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')
    seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed)

    # load data
    if args.dataset == "cifar10":
        data_loader = load_partition_data_distributed_cifar10
    elif args.dataset == "cifar100":
        data_loader = load_partition_data_distributed_cifar100
    elif args.dataset == "cinic10":
        data_loader = load_partition_data_distributed_cinic10
    else:
        data_loader = load_partition_data_distributed_cifar10

    train_data_num, train_data_global, \
    test_data_global, local_data_num, \
    train_data_local, test_data_local, class_num = data_loader(process_id, args.dataset, args.data_dir,
                                                               args.partition_method, args.partition_alpha,
                                                               args.client_number, args.batch_size)

    # create the model
    split_layer = 1
    if args.model == "mobilenet":
        model = mobilenet(class_num=class_num)
    else:  # default to resnet56, mirroring the dataset fallback above
        model = resnet56(class_num=class_num)

    # rebuild the classifier head so flattened features feed a fresh linear layer
    fc_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Flatten(),
                             nn.Linear(fc_features, class_num))

    # split the model: the first `split_layer` top-level children go to the
    # clients, the rest stays on the server
    client_model = nn.Sequential(*nn.ModuleList(model.children())[:split_layer])
    server_model = nn.Sequential(*nn.ModuleList(model.children())[split_layer:])

    SplitNN_distributed(process_id, worker_number, device, comm,
                        client_model, server_model, train_data_num,
                        train_data_global, test_data_global, local_data_num,
                        train_data_local, test_data_local, args)
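A sketch of the split performed above, using torchvision's resnet18 as a stand-in for fedml's resnet56 (an assumption for illustration): children() yields the top-level modules in definition order, so slicing at split_layer leaves the stem on the client and everything else, including the rebuilt fc head, on the server.

import torch
import torch.nn as nn
from torchvision.models import resnet18  # stand-in for resnet56

model = resnet18(num_classes=10)
fc_features = model.fc.in_features
model.fc = nn.Sequential(nn.Flatten(), nn.Linear(fc_features, 10))

split_layer = 1
children = list(model.children())
client_model = nn.Sequential(*children[:split_layer])  # just the first conv
server_model = nn.Sequential(*children[split_layer:])  # the rest of the network

x = torch.randn(2, 3, 32, 32)
assert server_model(client_model(x)).shape == (2, 10)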
