From bfda5eafbc8d671592367d533436f63dd0fba2b2 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 27 Jul 2020 03:44:13 +0000
Subject: [PATCH] fix arguments and add first working version of the code

---
 fedml_api/distributed/split_nn/SplitNNAPI.py   |  34 +++++
 fedml_api/distributed/split_nn/client.py       |  75 +++++++++
 .../distributed/split_nn/message_define.py     |  25 +++
 fedml_api/distributed/split_nn/server.py       |  88 +++++++++++
 .../distributed/split_nn/main_split_nn.py      | 143 ++++++++++++++++++
 5 files changed, 365 insertions(+)
 create mode 100644 fedml_api/distributed/split_nn/SplitNNAPI.py
 create mode 100644 fedml_api/distributed/split_nn/client.py
 create mode 100644 fedml_api/distributed/split_nn/message_define.py
 create mode 100644 fedml_api/distributed/split_nn/server.py
 create mode 100644 fedml_experiments/distributed/split_nn/main_split_nn.py

diff --git a/fedml_api/distributed/split_nn/SplitNNAPI.py b/fedml_api/distributed/split_nn/SplitNNAPI.py
new file mode 100644
index 0000000000..dbf47c2dbf
--- /dev/null
+++ b/fedml_api/distributed/split_nn/SplitNNAPI.py
@@ -0,0 +1,34 @@
+from mpi4py import MPI
+
+from fedml_api.distributed.split_nn.server import SplitNN_server
+from fedml_api.distributed.split_nn.client import SplitNN_client
+
+def SplitNN_init():
+    comm = MPI.COMM_WORLD
+    process_id = comm.Get_rank()
+    worker_number = comm.Get_size()
+    return comm, process_id, worker_number
+
+def SplitNN_distributed(process_id, worker_number, device, comm, client_model,
+                        server_model, train_data_num, train_data_global, test_data_global,
+                        local_data_num, train_data_local, test_data_local, args):
+    if process_id == 0:
+        init_server(comm, server_model, worker_number, device)
+    else:
+        server_rank = 0
+        init_client(comm, client_model, worker_number, train_data_local, test_data_local,
+                    process_id, server_rank, args.epochs, device)
+
+def init_server(comm, server_model, worker_number, device):
+    arg_dict = {"comm": comm, "model": server_model, "max_rank": worker_number - 1,
+                "device": device}
+    server = SplitNN_server(arg_dict)
+    server.run()
+
+def init_client(comm, client_model, worker_number, train_data_local, test_data_local,
+                process_id, server_rank, epochs, device):
+    arg_dict = {"comm": comm, "trainloader": train_data_local, "testloader": test_data_local,
+                "model": client_model, "rank": process_id, "server_rank": server_rank,
+                "max_rank": worker_number - 1, "epochs": epochs, "device": device}
+    client = SplitNN_client(arg_dict)
+    client.run()
diff --git a/fedml_api/distributed/split_nn/client.py b/fedml_api/distributed/split_nn/client.py
new file mode 100644
index 0000000000..7f1b2f4cfd
--- /dev/null
+++ b/fedml_api/distributed/split_nn/client.py
@@ -0,0 +1,75 @@
+import logging
+import torch
+import torch.optim as optim
+
+from fedml_api.distributed.split_nn.message_define import MyMessage
+from fedml_core.distributed.client.client_manager import ClientMananger
+from fedml_core.distributed.communication import Message
+
+class SplitNN_client():
+    def __init__(self, args):
+        self.comm = args["comm"]
+        self.model = args["model"]
+        self.trainloader = args["trainloader"]
+        self.testloader = args["testloader"]
+        self.rank = args["rank"]
+        self.MAX_RANK = args["max_rank"]
+        self.node_left = self.MAX_RANK if self.rank == 1 else self.rank - 1
+        self.node_right = 1 if self.rank == self.MAX_RANK else self.rank + 1
+        self.epoch_count = 0
+        self.MAX_EPOCH_PER_NODE = args["epochs"]
+        self.SERVER_RANK = args["server_rank"]
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
+                                   weight_decay=5e-4)
+
+        self.trainloader = args["trainloader"]
+        self.device = args["device"]
+
+    def run(self):
+        if self.rank == self.MAX_RANK:
+            logging.info("sending semaphore from {} to {}".format(self.rank,
+                                                                  self.node_right))
+            self.comm.send("semaphore", dest=self.node_right)
+
+        while True:
+            signal = self.comm.recv(source=self.node_left)
+
+            if signal == "semaphore":
+                logging.info("Starting training at node {}".format(self.rank))
+
+            for batch_idx, (inputs, labels) in enumerate(self.trainloader):
+                inputs, labels = inputs.to(self.device), labels.to(self.device)
+
+                self.optimizer.zero_grad()
+
+                intermed_tensor = self.model(inputs)
+                self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)
+                grads = self.comm.recv(source=self.SERVER_RANK)
+
+                intermed_tensor.backward(grads)
+                self.optimizer.step()
+
+            logging.info("Epoch over at node {}".format(self.rank))
+            del intermed_tensor, grads, inputs, labels
+            torch.cuda.empty_cache()
+
+            # Validation loss
+            self.comm.send("validation", dest=self.SERVER_RANK)
+            for batch_idx, (inputs, labels) in enumerate(self.testloader):
+                inputs, labels = inputs.to(self.device), labels.to(self.device)
+                intermed_tensor = self.model(inputs)
+                self.comm.send([intermed_tensor, labels], dest=self.SERVER_RANK)
+
+            del intermed_tensor, inputs, labels
+            torch.cuda.empty_cache()
+
+            self.epoch_count += 1
+            self.comm.send("semaphore", dest=self.node_right)
+            # self.comm.send(model.state_dict(), dest=node_right)
+            if self.epoch_count == self.MAX_EPOCH_PER_NODE:
+                if self.rank == self.MAX_RANK:
+                    self.comm.send("over", dest=self.SERVER_RANK)
+                self.comm.send("done", dest=self.SERVER_RANK)
+                break
+            self.comm.send("done", dest=self.SERVER_RANK)
+
diff --git a/fedml_api/distributed/split_nn/message_define.py b/fedml_api/distributed/split_nn/message_define.py
new file mode 100644
index 0000000000..378a52d68e
--- /dev/null
+++ b/fedml_api/distributed/split_nn/message_define.py
@@ -0,0 +1,25 @@
+class MyMessage(object):
+    """
+    message type definition
+    """
+    # server to client
+    MSG_TYPE_S2C_INIT_CONFIG = 1
+    MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2
+
+    # client to server
+    MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3
+    MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4
+
+    MSG_ARG_KEY_TYPE = "msg_type"
+    MSG_ARG_KEY_SENDER = "sender"
+    MSG_ARG_KEY_RECEIVER = "receiver"
+
+    """
+    message payload keywords definition
+    """
+    MSG_ARG_KEY_NUM_SAMPLES = "num_samples"
+    MSG_ARG_KEY_MODEL_PARAMS = "model_params"
+    MSG_ARG_KEY_LOCAL_TRAINING_ACC = "local_training_acc"
+    MSG_ARG_KEY_LOCAL_TRAINING_LOSS = "local_training_loss"
+    MSG_ARG_KEY_LOCAL_TEST_ACC = "local_test_acc"
+    MSG_ARG_KEY_LOCAL_TEST_LOSS = "local_test_loss"
diff --git a/fedml_api/distributed/split_nn/server.py b/fedml_api/distributed/split_nn/server.py
new file mode 100644
index 0000000000..fad0c7eb90
--- /dev/null
+++ b/fedml_api/distributed/split_nn/server.py
@@ -0,0 +1,88 @@
+import logging
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from fedml_api.distributed.split_nn.message_define import MyMessage
+from fedml_core.distributed.communication import Message
+from fedml_core.distributed.server.server_manager import ServerMananger
+
+
+class SplitNN_server():
+    def __init__(self, args):
+        self.comm = args["comm"]
+        self.model = args["model"]
+        self.MAX_RANK = args["max_rank"]
+        self.active_node = 1
+        self.epoch = 0
+        self.batch_idx = 0
+        self.step = 0
+        self.log_step = 50
+        self.active_node = 1
+        self.phase = "train"
+        self.val_loss = 0
+        self.total = 0
+        self.correct = 0
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9,
+                                   weight_decay=5e-4)
+        self.criterion = nn.CrossEntropyLoss()
+
+    def run(self):
+        while True:
+            message = self.comm.recv(source=self.active_node)
+            if message == "done":
+                # not a precise estimate of validation loss
+                self.val_loss /= self.step
+                acc = self.correct / self.total
+                logging.info("phase={} acc={} loss={} epoch={} and step={}"
+                             .format(self.phase, acc, self.val_loss, self.epoch, self.step))
+
+                self.epoch += 1
+                self.active_node = (self.active_node % self.MAX_RANK) + 1
+                self.phase = "train"
+                self.total = 0
+                self.correct = 0
+                self.val_loss = 0
+                self.step = 0
+                self.batch_idx = 0
+                logging.info("current active client is {}".format(self.active_node))
+            elif message == "over":
+                logging.info("training over")
+                break
+            elif message == "validation":
+                self.phase = "validation"
+                self.step = 0
+                self.total = 0
+                self.correct = 0
+            else:
+                if self.phase == "train":
+                    logging.debug("Server-Receive: client={}, index={}, time={}"
+                                  .format(self.active_node, self.batch_idx,
+                                          datetime.datetime.now()))
+                    self.optimizer.zero_grad()
+                input_tensor, labels = message
+                input_tensor.retain_grad()
+                logits = self.model(input_tensor)
+                _, predictions = logits.max(1)
+
+                loss = self.criterion(logits, labels)
+                self.loss = loss
+                self.total += labels.size(0)
+                self.correct += predictions.eq(labels).sum().item()
+
+                if self.phase == "train":
+                    loss.backward()
+                    self.optimizer.step()
+                    self.comm.send(input_tensor.grad, dest=self.active_node)
+                    self.batch_idx += 1
+
+                self.step += 1
+                if self.step % self.log_step == 0 and self.phase == "train":
+                    acc = self.correct / self.total
+                    logging.info("phase={} acc={} loss={} epoch={} and step={}"
+                                 .format("train", acc, loss.item(), self.epoch, self.step))
+                if self.phase == "validation":
+                    self.val_loss += loss.item()
+
diff --git a/fedml_experiments/distributed/split_nn/main_split_nn.py b/fedml_experiments/distributed/split_nn/main_split_nn.py
new file mode 100644
index 0000000000..6d57decf0d
--- /dev/null
+++ b/fedml_experiments/distributed/split_nn/main_split_nn.py
@@ -0,0 +1,143 @@
+import setproctitle
+import argparse
+import logging
+import sys
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
+
+from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_distributed_cifar100
+from fedml_api.data_preprocessing.cinic10.data_loader import load_partition_data_distributed_cinic10
+from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_distributed_cifar10
+
+from fedml_api.distributed.split_nn.SplitNNAPI import SplitNN_init, SplitNN_distributed
+from fedml_api.model.deep_neural_networks.mobilenet import mobilenet
+from fedml_api.model.deep_neural_networks.resnet import resnet56
+
+def add_args(parser):
+    """
+    parser : argparse.ArgumentParser
+    return the parsed arguments required for training
+    """
+    # Training settings
+    parser.add_argument('--model', type=str, default='resnet56', metavar='N',
+                        help='neural network used in training')
+
+    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
+                        help='dataset used for training')
+
+    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
+                        help='data directory')
+
+    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
+                        help='how to partition the dataset on local workers')
+
+    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
+                        help='partition alpha (default: 0.5)')
+
+    parser.add_argument('--client_number', type=int, default=16, metavar='NN',
+                        help='number of workers in a distributed cluster')
+
+    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+
+    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                        help='learning rate (default: 0.001)')
+
+    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=0.001)
+
+    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
+                        help='how many epochs will be trained locally')
+
+    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
+                        help='the approximate fixed number of data points we will have on each local worker')
+
+    parser.add_argument('--comm_round', type=int, default=10,
+                        help='how many rounds of communication we should use')
+
+    parser.add_argument('--frequency_of_the_test', type=int, default=1,
+                        help='the frequency of running the test')
+
+    parser.add_argument('--gpu_server_num', type=int, default=1,
+                        help='number of GPU servers')
+
+    parser.add_argument('--gpu_num_per_server', type=int, default=4,
+                        help='number of GPUs per server')
+    args = parser.parse_args()
+    return args
+
+
+def init_training_device(process_ID, fl_worker_num, gpu_num_per_machine):
+    # initialize the mapping from process ID to GPU ID
+    if process_ID == 0:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        return device
+    process_gpu_dict = dict()
+    for client_index in range(fl_worker_num):
+        gpu_index = client_index % gpu_num_per_machine
+        process_gpu_dict[client_index] = gpu_index
+
+    logging.info(process_gpu_dict)
+    device = torch.device("cuda:" + str(process_gpu_dict[process_ID - 1]) if torch.cuda.is_available() else "cpu")
+    logging.info(device)
+    return device
+
+
+if __name__ == "__main__":
+    comm, process_id, worker_number = SplitNN_init()
+
+    parser = argparse.ArgumentParser()
+    args = add_args(parser)
+
+    device = init_training_device(process_id, worker_number - 1, args.gpu_num_per_server)
+
+    str_process_name = "SplitNN (distributed):" + str(process_id)
+    setproctitle.setproctitle(str_process_name)
+
+    logging.basicConfig(level=logging.INFO,
+                        format=str(
+                            process_id) + ' - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
+                        datefmt='%a, %d %b %Y %H:%M:%S')
+    seed = 0
+    np.random.seed(seed)
+    torch.manual_seed(worker_number)
+
+    # load data
+    if args.dataset == "cifar10":
+        data_loader = load_partition_data_distributed_cifar10
+    elif args.dataset == "cifar100":
+        data_loader = load_partition_data_distributed_cifar100
+    elif args.dataset == "cinic10":
+        data_loader = load_partition_data_distributed_cinic10
+    else:
+        data_loader = load_partition_data_distributed_cifar10
+
+    train_data_num, train_data_global, \
+    test_data_global, local_data_num, \
+    train_data_local, test_data_local, class_num = data_loader(process_id, args.dataset, args.data_dir,
+                                                                args.partition_method, args.partition_alpha,
+                                                                args.client_number, args.batch_size)
+
+    # create the model
+    model = None
+    split_layer = 1
+    if args.model == "mobilenet":
+        model = mobilenet(class_num=class_num)
+    elif args.model == "resnet56":
+        model = resnet56(class_num=class_num)
+
+    fc_features = model.fc.in_features
+    model.fc = nn.Sequential(nn.Flatten(),
+                             nn.Linear(fc_features, class_num))
+    # Split the model into client and server parts
+    client_model = nn.Sequential(*nn.ModuleList(model.children())[:split_layer])
+    server_model = nn.Sequential(*nn.ModuleList(model.children())[split_layer:])
+
+
+    SplitNN_distributed(process_id, worker_number, device, comm,
+                        client_model, server_model, train_data_num,
+                        train_data_global, test_data_global, local_data_num,
+                        train_data_local, test_data_local, args)