forked from FedML-AI/FedML
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request FedML-AI#1 from tremblerz/split_learning_distributed
Distributed version for Split Learning
- Loading branch information
Showing
5 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from mpi4py import MPI | ||
|
||
from fedml_api.distributed.split_nn.server import SplitNN_server | ||
from fedml_api.distributed.split_nn.client import SplitNN_client | ||
|
||
def SplitNN_init():
    """Initialize the MPI environment for split learning.

    Returns:
        tuple: ``(comm, process_id, worker_number)`` — the MPI world
        communicator, this process's rank, and the total process count.
    """
    world = MPI.COMM_WORLD
    rank = world.Get_rank()
    size = world.Get_size()
    return world, rank, size
|
||
def SplitNN_distributed(process_id, worker_number, device, comm, client_model,
                        server_model, train_data_num, train_data_global, test_data_global,
                        local_data_num, train_data_local, test_data_local, args):
    """Start this process's split-learning role based on its MPI rank.

    Rank 0 runs the server half of the model; every other rank runs a
    client against server rank 0 using its local data partitions.
    """
    if process_id == 0:
        init_server(comm, server_model, worker_number, device)
        return
    # All non-zero ranks are clients talking to the server at rank 0.
    init_client(comm, client_model, worker_number, train_data_local, test_data_local,
                process_id, 0, args.epochs, device)
|
||
def init_server(comm, server_model, worker_number, device):
    """Build and run the server-side SplitNN endpoint (blocks until done)."""
    SplitNN_server({
        "comm": comm,
        "model": server_model,
        "max_rank": worker_number - 1,  # highest client rank in the ring
        "device": device,
    }).run()
|
||
def init_client(comm, client_model, worker_number, train_data_local, test_data_local,
                process_id, server_rank, epochs, device):
    """Build and run a client-side SplitNN endpoint (blocks until done)."""
    config = {
        "comm": comm,
        "trainloader": train_data_local,
        "testloader": test_data_local,
        "model": client_model,
        "rank": process_id,
        "server_rank": server_rank,
        "max_rank": worker_number - 1,  # highest client rank in the ring
        "epochs": epochs,
        "device": device,
    }
    SplitNN_client(config).run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import logging | ||
import torch | ||
import torch.optim as optim | ||
|
||
from fedml_api.distributed.fedavg.message_define import MyMessage | ||
from fedml_core.distributed.client.client_manager import ClientMananger | ||
from fedml_core.distributed.communication import Message | ||
|
||
class SplitNN_client():
    """Client-side endpoint of distributed split learning.

    Clients occupy ranks 1..max_rank (rank 0 is the server) and form a
    logical ring. A "semaphore" token is passed around the ring; the client
    holding it runs one local epoch: for each batch it runs its model
    partition forward, ships the activations and labels to the server, and
    applies the activation gradients the server sends back.
    """

    def __init__(self, args):
        """Initialize from a config dict.

        args keys: "comm" (MPI communicator), "model" (client-side model
        partition), "trainloader"/"testloader" (data iterators), "rank",
        "max_rank" (highest client rank), "epochs" (epochs per client),
        "server_rank", "device".
        """
        self.comm = args["comm"]
        self.model = args["model"]
        self.trainloader = args["trainloader"]
        self.testloader = args["testloader"]
        self.rank = args["rank"]
        self.MAX_RANK = args["max_rank"]
        # Ring neighbours: rank 1 wraps to MAX_RANK on its left, and
        # MAX_RANK wraps to rank 1 on its right.
        self.node_left = self.MAX_RANK if self.rank == 1 else self.rank - 1
        self.node_right = 1 if self.rank == self.MAX_RANK else self.rank + 1
        self.epoch_count = 0
        self.MAX_EPOCH_PER_NODE = args["epochs"]
        self.SERVER_RANK = args["server_rank"]
        # Fix: original assigned self.trainloader twice; redundant duplicate removed.
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
                                   weight_decay=5e-4)
        self.device = args["device"]

    def run(self):
        """Participate in token-ring training until all epochs are done."""
        # The highest-ranked client kicks off the first round by handing
        # the token to rank 1 (its right neighbour).
        if self.rank == self.MAX_RANK:
            logging.info("sending semaphore from {} to {}".format(self.rank,
                                                                  self.node_right))
            self.comm.send("semaphore", dest=self.node_right)

        while(True):
            signal = self.comm.recv(source=self.node_left)

            if signal == "semaphore":
                logging.info("Starting training at node {}".format(self.rank))

                for batch_idx, (inputs, labels) in enumerate(self.trainloader):
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    self.optimizer.zero_grad()

                    # Forward through the client partition, hand activations
                    # to the server, then backprop the returned gradients.
                    intermed_tensor = self.model(inputs)
                    self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)
                    grads = self.comm.recv(source=self.SERVER_RANK)

                    intermed_tensor.backward(grads)
                    self.optimizer.step()

                logging.info("Epoch over at node {}".format(self.rank))
                del intermed_tensor, grads, inputs, labels
                torch.cuda.empty_cache()

                # Validation: forward-only passes; the server accumulates
                # loss/accuracy on its side.
                self.comm.send("validation", dest=self.SERVER_RANK)
                for batch_idx, (inputs, labels) in enumerate(self.testloader):
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    intermed_tensor = self.model(inputs)
                    self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)

                del intermed_tensor, inputs, labels
                torch.cuda.empty_cache()

                self.epoch_count += 1
                # Hand the token to the next client in the ring.
                self.comm.send("semaphore", dest=self.node_right)
                # self.comm.send(model.state_dict(), dest=node_right)
                if self.epoch_count == self.MAX_EPOCH_PER_NODE:
                    if self.rank == self.MAX_RANK:
                        # Final client, final epoch: tell the server to stop.
                        # NOTE(review): the server breaks on "over", so the
                        # "done" sent right after may never be consumed — confirm.
                        self.comm.send("over", dest=self.SERVER_RANK)
                    self.comm.send("done", dest=self.SERVER_RANK)
                    break
                self.comm.send("done", dest=self.SERVER_RANK)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
class MyMessage(object):
    """Message-type codes and payload-key names for FedAvg-style messaging."""

    # --- message types: server -> client ---
    MSG_TYPE_S2C_INIT_CONFIG = 1
    MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2

    # --- message types: client -> server ---
    MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3
    MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4

    # --- envelope keys present on every message ---
    MSG_ARG_KEY_TYPE = "msg_type"
    MSG_ARG_KEY_SENDER = "sender"
    MSG_ARG_KEY_RECEIVER = "receiver"

    # --- payload keys ---
    MSG_ARG_KEY_NUM_SAMPLES = "num_samples"
    MSG_ARG_KEY_MODEL_PARAMS = "model_params"
    MSG_ARG_KEY_LOCAL_TRAINING_ACC = "local_training_acc"
    MSG_ARG_KEY_LOCAL_TRAINING_LOSS = "local_training_loss"
    MSG_ARG_KEY_LOCAL_TEST_ACC = "local_test_acc"
    MSG_ARG_KEY_LOCAL_TEST_LOSS = "local_test_loss"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import logging | ||
import datetime | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
|
||
from fedml_api.distributed.split_nn.message_define import MyMessage | ||
from fedml_core.distributed.communication import Message | ||
from fedml_core.distributed.server.server_manager import ServerMananger | ||
|
||
|
||
class SplitNN_server():
    # Server half of distributed split learning: receives intermediate
    # activations from the currently active client, finishes the forward
    # pass, and (in the train phase) returns activation gradients.
    def __init__(self, args):
        # args: dict with "comm" (MPI communicator), "model" (server-side
        # model partition), "max_rank" (highest client rank), "device".
        self.comm = args["comm"]
        self.model = args["model"]
        self.MAX_RANK = args["max_rank"]
        self.active_node = 1
        self.epoch = 0
        self.batch_idx = 0
        self.step = 0
        self.log_step = 50  # log training metrics every 50 steps
        self.active_node = 1  # NOTE(review): duplicate of the assignment above
        self.phase = "train"
        self.val_loss = 0
        self.total = 0    # samples seen in the current phase
        self.correct = 0  # correct predictions in the current phase
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
                                   weight_decay=5e-4)
        self.criterion = nn.CrossEntropyLoss()

    def run(self, ):
        # Message loop: strings are control messages; anything else is
        # assumed to be a [activations, labels] pair from the active client.
        while(True):
            message = self.comm.recv(source=self.active_node)
            if message == "done":
                # Active client finished its epoch (train + validation):
                # log, rotate to the next client, and reset phase counters.
                # not a precise estimate of validation loss
                self.val_loss /= self.step
                acc = self.correct / self.total
                # NOTE(review): this logs self.loss (the last batch's loss)
                # rather than the accumulated self.val_loss — confirm intent.
                logging.info("phase={} acc={} loss={} epoch={} and step={}"
                             .format(self.phase, acc, self.loss.item(), self.epoch, self.step))

                self.epoch += 1
                # Round-robin over client ranks 1..MAX_RANK.
                self.active_node = (self.active_node % self.MAX_RANK) + 1
                self.phase = "train"
                self.total = 0
                self.correct = 0
                self.val_loss = 0
                self.step = 0
                self.batch_idx = 0
                logging.info("current active client is {}".format(self.active_node))
            elif message == "over":
                logging.info("training over")
                break
            elif message == "validation":
                # Client switched to its validation pass: reset counters.
                self.phase = "validation"
                self.step = 0
                self.total = 0
                self.correct = 0
            else:
                if self.phase == "train":
                    logging.debug("Server-Receive: client={}, index={}, time={}"
                                  .format(self.active_node, self.batch_idx,
                                          datetime.datetime.now()))
                self.optimizer.zero_grad()
                input_tensor, labels = message
                # retain_grad() so input_tensor.grad is populated by backward
                # and can be shipped back to the client.
                input_tensor.retain_grad()
                logits = self.model(input_tensor)
                _, predictions = logits.max(1)

                loss = self.criterion(logits, labels)
                self.loss = loss
                self.total += labels.size(0)
                self.correct += predictions.eq(labels).sum().item()

                if self.phase == "train":
                    loss.backward()
                    self.optimizer.step()
                    # Gradient w.r.t. received activations goes back to the
                    # client, which continues backprop through its partition.
                    self.comm.send(input_tensor.grad, dest=self.active_node)
                    self.batch_idx += 1

                self.step += 1
                if self.step % self.log_step == 0 and self.phase == "train":
                    acc = self.correct / self.total
                    logging.info("phase={} acc={} loss={} epoch={} and step={}"
                                 .format("train", acc, loss.item(), self.epoch, self.step))
                if self.phase == "validation":
                    self.val_loss += loss.item()
|
143 changes: 143 additions & 0 deletions
143
fedml_experiments/distributed/split_nn/main_split_nn.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import setproctitle | ||
import argparse | ||
import logging | ||
import sys | ||
import os | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../"))) | ||
|
||
from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_distributed_cifar100 | ||
from fedml_api.data_preprocessing.cinic10.data_loader import load_partition_data_distributed_cinic10 | ||
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_distributed_cifar10 | ||
|
||
from fedml_api.distributed.split_nn.SplitNNAPI import SplitNN_init, SplitNN_distributed | ||
from fedml_api.model.deep_neural_networks.mobilenet import mobilenet | ||
from fedml_api.model.deep_neural_networks.resnet import resnet56 | ||
|
||
def add_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    # Training settings, declared data-style as (flag, add_argument kwargs).
    specs = [
        ('--model', dict(type=str, default='resnet56', metavar='N',
                         help='neural network used in training')),
        ('--dataset', dict(type=str, default='cifar10', metavar='N',
                           help='dataset used for training')),
        ('--data_dir', dict(type=str, default='./../../../data/cifar10',
                            help='data directory')),
        ('--partition_method', dict(type=str, default='hetero', metavar='N',
                                    help='how to partition the dataset on local workers')),
        ('--partition_alpha', dict(type=float, default=0.5, metavar='PA',
                                   help='partition alpha (default: 0.5)')),
        ('--client_number', dict(type=int, default=16, metavar='NN',
                                 help='number of workers in a distributed cluster')),
        ('--batch_size', dict(type=int, default=64, metavar='N',
                              help='input batch size for training (default: 64)')),
        ('--lr', dict(type=float, default=0.001, metavar='LR',
                      help='learning rate (default: 0.001)')),
        ('--wd', dict(help='weight decay parameter;', type=float, default=0.001)),
        ('--epochs', dict(type=int, default=5, metavar='EP',
                          help='how many epochs will be trained locally')),
        ('--local_points', dict(type=int, default=5000, metavar='LP',
                                help='the approximate fixed number of data points we will have on each local worker')),
        ('--comm_round', dict(type=int, default=10,
                              help='how many round of communications we shoud use')),
        ('--frequency_of_the_test', dict(type=int, default=1,
                                         help='the frequency of the algorithms')),
        ('--gpu_server_num', dict(type=int, default=1,
                                  help='gpu_server_num')),
        ('--gpu_num_per_server', dict(type=int, default=4,
                                      help='gpu_num_per_server')),
    ]
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
||
|
||
def init_training_device(process_ID, fl_worker_num, gpu_num_per_machine):
    """Map this process to a torch device.

    Rank 0 (the server) takes cuda:0 when available; client ranks are
    assigned GPUs round-robin across the machine's GPUs. Falls back to
    CPU when CUDA is unavailable.
    """
    use_cuda = torch.cuda.is_available()
    if process_ID == 0:
        return torch.device("cuda:0" if use_cuda else "cpu")
    # Round-robin mapping from client index (process_ID - 1) to GPU index.
    process_gpu_dict = {idx: idx % gpu_num_per_machine for idx in range(fl_worker_num)}
    logging.info(process_gpu_dict)
    gpu_id = process_gpu_dict[process_ID - 1]
    device = torch.device("cuda:" + str(gpu_id) if use_cuda else "cpu")
    logging.info(device)
    return device
|
||
|
||
if __name__ == "__main__":
    # Initialize MPI: rank 0 becomes the SplitNN server, ranks 1..N clients.
    comm, process_id, worker_number = SplitNN_init()

    parser = argparse.ArgumentParser()
    args = add_args(parser)

    # worker_number - 1 clients share the GPUs; rank 0 gets cuda:0/cpu.
    device = init_training_device(process_id, worker_number - 1, args.gpu_num_per_server)

    # Name the process so it is identifiable in ps/top output.
    str_process_name = "SplitNN (distributed):" + str(process_id)
    setproctitle.setproctitle(str_process_name)

    logging.basicConfig(level=logging.INFO,
                        format=str(
                            process_id) + ' - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')
    seed = 0
    np.random.seed(seed)
    # NOTE(review): torch is seeded with worker_number, not `seed` — looks
    # unintentional; confirm whether torch.manual_seed(seed) was meant.
    torch.manual_seed(worker_number)

    # load data: pick the partitioned-distributed loader for the dataset.
    if args.dataset == "cifar10":
        data_loader = load_partition_data_distributed_cifar10
    elif args.dataset == "cifar100":
        data_loader = load_partition_data_distributed_cifar100
    elif args.dataset == "cinic10":
        data_loader = load_partition_data_distributed_cinic10
    else:
        # Unrecognized dataset names silently fall back to CIFAR-10.
        data_loader = load_partition_data_distributed_cifar10

    train_data_num, train_data_global,\
    test_data_global, local_data_num, \
    train_data_local, test_data_local, class_num = data_loader(process_id, args.dataset, args.data_dir,
                                                               args.partition_method, args.partition_alpha,
                                                               args.client_number, args.batch_size)

    # create the model
    model = None
    split_layer = 1
    if args.model == "mobilenet":
        model = mobilenet(class_num=class_num)
    elif args.model == "resnet56":
        model = resnet56(class_num=class_num)

    # Replace the final FC layer so flattened features feed a fresh head
    # sized for this dataset's class count.
    fc_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Flatten(),
                             nn.Linear(fc_features, class_num))
    # Split The model: the first `split_layer` top-level children run on
    # the client; the remaining children run on the server.
    client_model = nn.Sequential(*nn.ModuleList(model.children())[:split_layer])
    server_model = nn.Sequential(*nn.ModuleList(model.children())[split_layer:])

    SplitNN_distributed(process_id, worker_number, device, comm,
                        client_model, server_model, train_data_num,
                        train_data_global, test_data_global, local_data_num,
                        train_data_local, test_data_local, args)