From 4eeff5746efb86886b104ced128babee63390fc2 Mon Sep 17 00:00:00 2001
From: qbc
Date: Tue, 22 Nov 2022 18:15:44 +0800
Subject: [PATCH] Update xgb algo files (#439)

- Refine each client's test data
- Refine testing procedure
- Add 3 datasets
---
 .../core/auxiliaries/data_builder.py          |   2 +-
 federatedscope/core/data/utils.py             |   2 +-
 .../vertical_fl/dataloader/dataloader.py      |  54 ++++++
 federatedscope/vertical_fl/dataset/abalone.py | 153 +++++++++++++++++
 federatedscope/vertical_fl/dataset/adult.py   |  14 +-
 federatedscope/vertical_fl/dataset/blog.py    | 159 ++++++++++++++++++
 federatedscope/vertical_fl/dataset/credit.py  | 155 +++++++++++++++++
 .../baseline/xgb_base_on_abalone.yaml         |  33 ++++
 .../xgb_base/baseline/xgb_base_on_adult.yaml  |   4 +-
 .../baseline/xgb_base_on_blogfeedback.yaml    |  33 ++++
 .../xgb_base_on_givemesomecredit.yaml         |  33 ++++
 .../vertical_fl/xgb_base/worker/Test_base.py  |  75 +++++----
 .../vertical_fl/xgb_base/worker/XGBClient.py  |  48 ++----
 .../vertical_fl/xgb_base/worker/XGBServer.py  |  49 ++----
 14 files changed, 708 insertions(+), 106 deletions(-)
 create mode 100644 federatedscope/vertical_fl/dataset/abalone.py
 create mode 100644 federatedscope/vertical_fl/dataset/blog.py
 create mode 100644 federatedscope/vertical_fl/dataset/credit.py
 create mode 100644 federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml
 create mode 100644 federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_blogfeedback.yaml
 create mode 100644 federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_givemesomecredit.yaml

diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py
index 33152c330..7e3e3cf92 100644
--- a/federatedscope/core/auxiliaries/data_builder.py
+++ b/federatedscope/core/auxiliaries/data_builder.py
@@ -28,7 +28,7 @@
     'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?',
     '.*?movielens.*?', '.*?cikmcup.*?', 'graph_multi_domain.*?', 'cora',
     'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm.*?', 'fb15k-237',
-    'wn18', 'adult'
+    'wn18', 'adult', 'abalone', 'credit', 'blog'
 ],  # Dummy for FL dataset
 }
 DATA_TRANS_MAP = RegexInverseMap(TRANS_DATA_MAP, None)
diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py
index 3f23863da..d3241ed58 100644
--- a/federatedscope/core/data/utils.py
+++ b/federatedscope/core/data/utils.py
@@ -82,7 +82,7 @@ def load_dataset(config):
     elif config.data.type.lower() == 'vertical_fl_data':
         from federatedscope.vertical_fl.dataloader import load_vertical_data
         dataset, modified_config = load_vertical_data(config, generate=True)
-    elif config.data.type.lower() in ['adult']:
+    elif config.data.type.lower() in ['adult', 'abalone', 'credit', 'blog']:
         from federatedscope.vertical_fl.dataloader import load_vertical_data
         dataset, modified_config = load_vertical_data(config, generate=False)
     elif 'movielens' in config.data.type.lower(
diff --git a/federatedscope/vertical_fl/dataloader/dataloader.py b/federatedscope/vertical_fl/dataloader/dataloader.py
index ca7eaa283..804574901 100644
--- a/federatedscope/vertical_fl/dataloader/dataloader.py
+++ b/federatedscope/vertical_fl/dataloader/dataloader.py
@@ -1,6 +1,10 @@
 import numpy as np
 
 from federatedscope.vertical_fl.dataset.adult import Adult
+from federatedscope.vertical_fl.dataset.abalone import Abalone
+from federatedscope.vertical_fl.dataset.credit \
+    import Credit
+from federatedscope.vertical_fl.dataset.blog import Blog
 
 
 def load_vertical_data(config=None, generate=False):
@@ -24,6 +28,8 @@ def load_vertical_data(config=None, generate=False):
     elif config.xgb_base.use:
         feature_partition = config.xgb_base.dims
         algo = 'xgb'
+    else:
+        raise ValueError('Enable lr_base or xgb_base to set the partition')
 
     if config.data.args:
         args = config.data.args[0]
@@ -42,6 +48,54 @@ def load_vertical_data(config=None, generate=False):
                         algo=algo)
         data = dataset.data
         return data, config
+    elif name == 'credit':
+        dataset = Credit(root=path,
+                         name=name,
+                         num_of_clients=config.federate.client_num,
+                         feature_partition=feature_partition,
+                         tr_frac=splits[0],
+                         download=True,
+                         seed=1234,
+                         args=args,
+                         algo=algo)
+        data = dataset.data
+        return data, config
+    elif name == 'adult':
+        dataset = Adult(root=path,
+                        name=name,
+                        num_of_clients=config.federate.client_num,
+                        feature_partition=feature_partition,
+                        tr_frac=splits[0],
+                        download=True,
+                        seed=1234,
+                        args=args,
+                        algo=algo)
+        data = dataset.data
+        return data, config
+    elif name == 'abalone':
+        dataset = Abalone(root=path,
+                          name=name,
+                          num_of_clients=config.federate.client_num,
+                          feature_partition=feature_partition,
+                          tr_frac=splits[0],
+                          download=True,
+                          seed=1234,
+                          args=args,
+                          algo=algo)
+        data = dataset.data
+        return data, config
+    elif name == 'blog':
+        dataset = Blog(root=path,
+                       name=name,
+                       num_of_clients=config.federate.client_num,
+                       feature_partition=feature_partition,
+                       tr_frac=splits[0],
+                       download=True,
+                       seed=1234,
+                       args=args,
+                       algo=algo)
+        data = dataset.data
+        return data, config
     elif generate:
         # generate toy data for running a vertical FL example
         INSTANCE_NUM = 1000
diff --git a/federatedscope/vertical_fl/dataset/abalone.py b/federatedscope/vertical_fl/dataset/abalone.py
new file mode 100644
index 000000000..bfa0292f0
--- /dev/null
+++ b/federatedscope/vertical_fl/dataset/abalone.py
@@ -0,0 +1,153 @@
+import logging
+import os
+import os.path as osp
+
+import pandas as pd
+from torchvision.datasets.utils import download_and_extract_archive
+
+logger = logging.getLogger(__name__)
+
+
+class Abalone:
+    """
+    Abalone Data Set
+    (https://archive.ics.uci.edu/ml/datasets/abalone)
+    Data Set Information:
+    Number of Instances: 4177
+    Number of Attributes: 8
+
+    Predicting the age of abalone from physical measurements.
+    Given is the attribute name, attribute type, the measurement unit
+    and a brief description.
+    The number of rings is the value to predict:
+    either as a continuous value or as a classification problem.
+
+    Name / Data Type / Measurement Unit / Description
+
+    Sex / nominal / -- / M, F, and I (infant)
+    Length / continuous / mm / Longest shell measurement
+    Diameter / continuous / mm / perpendicular to length
+    Height / continuous / mm / with meat in shell
+    Whole weight / continuous / grams / whole abalone
+    Shucked weight / continuous / grams / weight of meat
+    Viscera weight / continuous / grams / gut weight (after bleeding)
+    Shell weight / continuous / grams / after being dried
+    Rings / integer / -- / +1.5 gives the age in years
+
+    Arguments:
+        root (str): root path
+        name (str): name of dataset, 'abalone' or 'xxx'
+        num_of_clients (int): number of clients
+        feature_partition (list): cumulative column indices that split
+            the features among the clients
+        tr_frac (float): train set proportion for each task; default=0.8
+        args (dict): set True or False to decide whether
+            to normalize or standardize the data or not,
+            e.g., {'normalization': False, 'standardization': False}
+        algo (str): the algorithm to run, 'lr' or 'xgb'
+        download (bool): indicator to download dataset
+        seed: a random seed
+    """
+    base_folder = 'abalone'
+    url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com/abalone.zip'
+    raw_file = 'abalone.data'
+
+    def __init__(self,
+                 root,
+                 name,
+                 num_of_clients,
+                 feature_partition,
+                 args,
+                 algo=None,
+                 tr_frac=0.8,
+                 download=True,
+                 seed=123):
+        self.root = root
+        self.name = name
+        self.num_of_clients = num_of_clients
+        self.feature_partition = feature_partition
+        self.tr_frac = tr_frac
+        self.seed = seed
+        self.args = args
+        self.algo = algo
+        self.data_dict = {}
+        self.data = {}
+
+        if download:
+            self.download()
+        if not self._check_existence():
+            raise RuntimeError("Dataset not found or corrupted." +
+                               "You can use download=True to download it")
+
+        self._get_data()
+        self._partition_data()
+
+    def _get_data(self):
+        fpath = os.path.join(self.root, self.base_folder)
+        file = osp.join(fpath, self.raw_file)
+        data = self._read_raw(file)
+        data = self._process(data)
+        train_num = int(self.tr_frac * len(data))
+        self.data_dict['train'] = data[:train_num]
+        self.data_dict['test'] = data[train_num:]
+
+    def _read_raw(self, file_path):
+        data = pd.read_csv(file_path, header=None)
+        return data
+
+    def _process(self, data):
+        data[0] = data[0].replace({'F': 2, 'M': 1, 'I': 0})
+        data = data.values
+        return data
+
+    def _check_existence(self):
+        fpath = os.path.join(self.root, self.base_folder, self.raw_file)
+        return osp.exists(fpath)
+
+    def download(self):
+        if self._check_existence():
+            logger.info("Files already exist")
+            return
+        download_and_extract_archive(self.url,
+                                     os.path.join(self.root, self.base_folder),
+                                     filename=self.url.split('/')[-1])
+
+    def _partition_data(self):
+
+        x = self.data_dict['train'][:, :-1]
+        y = self.data_dict['train'][:, -1]
+
+        test_data = {
+            'x': self.data_dict['test'][:, :-1],
+            'y': self.data_dict['test'][:, -1]
+        }
+
+        test_x = test_data['x']
+        test_y = test_data['y']
+
+        self.data = dict()
+        for i in range(self.num_of_clients + 1):
+            self.data[i] = dict()
+            if i == 0:
+                self.data[0]['train'] = None
+                self.data[0]['test'] = test_data
+            elif i == 1:
+                self.data[1]['train'] = {'x': x[:, :self.feature_partition[0]]}
+                self.data[1]['test'] = {
+                    'x': test_x[:, :self.feature_partition[0]]
+                }
+            else:
+                self.data[i]['train'] = {
+                    'x': x[:,
+                           self.feature_partition[i -
+                                                  2]:self.feature_partition[i -
+                                                                            1]]
+                }
+                self.data[i]['test'] = {
+                    'x': test_x[:, self.feature_partition[i - 2]:self.
+                    feature_partition[i - 1]]
+                }
+            self.data[i]['val'] = None
+
+        self.data[self.num_of_clients]['train']['y'] = y[:]
+        self.data[self.num_of_clients]['test']['y'] = test_y[:]
diff --git a/federatedscope/vertical_fl/dataset/adult.py b/federatedscope/vertical_fl/dataset/adult.py
index 2db87e65d..43c4de70b 100644
--- a/federatedscope/vertical_fl/dataset/adult.py
+++ b/federatedscope/vertical_fl/dataset/adult.py
@@ -15,6 +15,8 @@ class Adult:
     (https://archive.ics.uci.edu/ml/datasets/adult)
     Fields
     The dataset contains 15 columns
+    Training set: 'adult.data', 32561 instances
+    Testing set: 'adult.test', 16281 instances
     Target field: Income
     -- The income is divided into two classes: <=50K and >50K
     Number of attributes: 14
@@ -30,7 +32,7 @@ class Adult:
         args (dict): set True or False to decide whether
             to normalize or standardize the data or not,
             e.g., {'normalization': False, 'standardization': False}
-        model(str): the running model, 'lr' or 'xgb'
+        algo (str): the algorithm to run, 'lr' or 'xgb'
         download (bool): indicator to download dataset
         seed: a random seed
@@ -146,8 +148,12 @@ def _partition_data(self, train_set, test_set):
             self.data[i] = dict()
             if i == 0:
                 self.data[0]['train'] = None
+                self.data[0]['test'] = test_data
             elif i == 1:
                 self.data[1]['train'] = {'x': x[:, :self.feature_partition[0]]}
+                self.data[1]['test'] = {
+                    'x': test_x[:, :self.feature_partition[0]]
+                }
             else:
                 self.data[i]['train'] = {
                     'x': x[:,
@@ -155,10 +161,14 @@ def _partition_data(self, train_set, test_set):
                            2]:self.feature_partition[i - 1]]
                 }
+                self.data[i]['test'] = {
+                    'x': test_x[:, self.feature_partition[i - 2]:self.
+                    feature_partition[i - 1]]
+                }
             self.data[i]['val'] = None
-            self.data[i]['test'] = test_data
 
         self.data[self.num_of_clients]['train']['y'] = y[:]
+        self.data[self.num_of_clients]['test']['y'] = test_y[:]
 
     def _check_existence(self, file):
         fpath = os.path.join(self.root, self.base_folder, file)
diff --git a/federatedscope/vertical_fl/dataset/blog.py b/federatedscope/vertical_fl/dataset/blog.py
new file mode 100644
index 000000000..ef1b366ef
--- /dev/null
+++ b/federatedscope/vertical_fl/dataset/blog.py
@@ -0,0 +1,159 @@
+import glob
+import logging
+import os
+import os.path as osp
+
+import numpy as np
+import pandas as pd
+
+from torchvision.datasets.utils import download_and_extract_archive
+
+logger = logging.getLogger(__name__)
+
+
+class Blog:
+    """
+    BlogFeedback Data Set
+    (https://archive.ics.uci.edu/ml/datasets/BlogFeedback)
+
+    Data Set Information:
+    This data originates from blog posts. The raw HTML-documents
+    of the blog posts were crawled and processed.
+    The prediction task associated with the data is the prediction
+    of the number of comments in the upcoming 24 hours. In order
+    to simulate this situation, we choose a basetime (in the past)
+    and select the blog posts that were published at most
+    72 hours before the selected base date/time. Then, we calculate
+    all the features of the selected blog posts from the information
+    that was available at the basetime, therefore each instance
+    corresponds to a blog post. The target is the number of
+    comments that the blog post received in the next 24 hours
+    relative to the basetime.
+
+    Number of Instances: 60021
+    Number of Attributes: 281, the last one is the number of comments
+        in the next 24 hours
+    Training set: 'blogData_train.csv', 52397 instances
+    Testing set: 'blogData_test*.csv', 60 files, 7624 instances in total
+
+    Arguments:
+        root (str): root path
+        name (str): name of dataset, 'blog' or 'xxx'
+        num_of_clients (int): number of clients
+        feature_partition (list): cumulative column indices that split
+            the features among the clients
+        tr_frac (float): train set proportion for each task; default=0.8
+        args (dict): set True or False to decide whether
+            to normalize or standardize the data or not,
+            e.g., {'normalization': False, 'standardization': False}
+        algo (str): the algorithm to run, 'lr' or 'xgb'
+        download (bool): indicator to download dataset
+        seed: a random seed
+    """
+    base_folder = 'blogfeedback'
+    url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com/BlogFeedback.zip'
+    raw_file = 'BlogFeedback.zip'
+
+    def __init__(self,
+                 root,
+                 name,
+                 num_of_clients,
+                 feature_partition,
+                 args,
+                 algo=None,
+                 tr_frac=0.8,
+                 download=True,
+                 seed=123):
+        super(Blog, self).__init__()
+        self.root = root
+        self.name = name
+        self.num_of_clients = num_of_clients
+        self.tr_frac = tr_frac
+        self.feature_partition = feature_partition
+        self.seed = seed
+        self.args = args
+        self.algo = algo
+        self.data_dict = {}
+        self.data = {}
+
+        if download:
+            self.download()
+        if not self._check_existence():
+            raise RuntimeError("Dataset not found or corrupted." +
+                               "You can use download=True to download it")
+
+        self._get_data()
+        self._partition_data()
+
+    def _get_data(self):
+        fpath = os.path.join(self.root, self.base_folder)
+        train_file = osp.join(fpath, 'blogData_train.csv')
+        train_data = self._read_raw(train_file)
+        test_files = glob.glob(osp.join(fpath, "blogData_test*.csv"))
+        test_files.sort()
+
+        flag = 0
+        for f in test_files:
+            f_data = self._read_raw(f)
+            if flag == 0:
+                test_data = f_data
+                flag = 1
+            else:
+                test_data = np.concatenate((test_data, f_data), axis=0)
+
+        self.data_dict['train'] = train_data
+        self.data_dict['test'] = test_data
+
+    def _read_raw(self, file_path):
+        data = pd.read_csv(file_path, header=None, usecols=list(range(281)))
+        data = data.values
+        return data
+
+    def _check_existence(self):
+        fpath = os.path.join(self.root, self.base_folder, self.raw_file)
+        return osp.exists(fpath)
+
+    def download(self):
+        if self._check_existence():
+            logger.info("Files already exist")
+            return
+        download_and_extract_archive(self.url,
+                                     os.path.join(self.root, self.base_folder),
+                                     filename=self.url.split('/')[-1])
+
+    def _partition_data(self):
+        x = self.data_dict['train'][:, :self.feature_partition[-1]]
+        y = self.data_dict['train'][:, self.feature_partition[-1]]
+        test_data = dict()
+        test_data['x'] = self.data_dict['test'][:, :self.feature_partition[-1]]
+        test_data['y'] = self.data_dict['test'][:, self.feature_partition[-1]]
+
+        test_x = test_data['x']
+        test_y = test_data['y']
+
+        self.data = dict()
+        for i in range(self.num_of_clients + 1):
+            self.data[i] = dict()
+            if i == 0:
+                self.data[0]['train'] = None
+                self.data[0]['test'] = test_data
+            elif i == 1:
+                self.data[1]['train'] = {'x': x[:, :self.feature_partition[0]]}
+                self.data[1]['test'] = {
+                    'x': test_x[:, :self.feature_partition[0]]
+                }
+            else:
+                self.data[i]['train'] = {
+                    'x': x[:,
+                           self.feature_partition[i -
+                                                  2]:self.feature_partition[i -
+                                                                            1]]
+                }
+                self.data[i]['test'] = {
+                    'x': test_x[:, self.feature_partition[i - 2]:self.
+                    feature_partition[i - 1]]
+                }
+            self.data[i]['val'] = None
+
+        self.data[self.num_of_clients]['train']['y'] = y[:]
+        self.data[self.num_of_clients]['test']['y'] = test_y[:]
diff --git a/federatedscope/vertical_fl/dataset/credit.py b/federatedscope/vertical_fl/dataset/credit.py
new file mode 100644
index 000000000..26c23d4a3
--- /dev/null
+++ b/federatedscope/vertical_fl/dataset/credit.py
@@ -0,0 +1,155 @@
+import logging
+import os
+import os.path as osp
+
+import numpy as np
+import pandas as pd
+from torchvision.datasets.utils import download_and_extract_archive
+
+logger = logging.getLogger(__name__)
+
+
+class Credit:
+    """
+    Give Me Some Credit Data Set
+    (https://www.kaggle.com/competitions/GiveMeSomeCredit)
+    Data Set: cs-training.csv, 150000 instances and 12 attributes
+    The first attribute is the user ID, which we do not need; the second
+    attribute is the label, determining whether a loan should be granted.
+
+    Arguments:
+        root (str): root path
+        name (str): name of dataset, 'credit' or 'xxx'
+        num_of_clients (int): number of clients
+        feature_partition (list): cumulative column indices that split
+            the features among the clients
+        tr_frac (float): train set proportion for each task; default=0.8
+        args (dict): set True or False to decide whether
+            to normalize or standardize the data or not,
+            e.g., {'normalization': False, 'standardization': False}
+        algo (str): the algorithm to run, 'lr' or 'xgb'
+        download (bool): indicator to download dataset
+        seed: a random seed
+    """
+    base_folder = 'givemesomecredit'
+    url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com/cs-training.zip'
+    raw_file = 'cs-training.csv'
+
+    def __init__(self,
+                 root,
+                 name,
+                 num_of_clients,
+                 feature_partition,
+                 args,
+                 algo=None,
+                 tr_frac=0.8,
+                 download=True,
+                 seed=123):
+        super(Credit, self).__init__()
+        self.root = root
+        self.name = name
+        self.num_of_clients = num_of_clients
+        self.feature_partition = feature_partition
+        self.tr_frac = tr_frac
+        self.seed = seed
+        self.args = args
+        self.algo = algo
+        self.data_dict = {}
+        self.data = {}
+
+        if download:
+            self.download()
+        if not self._check_existence():
+            raise RuntimeError("Dataset not found or corrupted." +
+                               "You can use download=True to download it")
+
+        self._get_data()
+        self._partition_data()
+
+    def _get_data(self):
+        fpath = os.path.join(self.root, self.base_folder)
+        file = osp.join(fpath, self.raw_file)
+        data = self._read_raw(file)
+        data = data[:, 1:]
+
+        # the following code is used to choose balanced data;
+        # it may be removed later
+        # '''
+        sample_size = 150000
+
+        def balance_sample(sample_size, y):
+            y_ones_idx = (y == 1).nonzero()[0]
+            y_ones_idx = np.random.choice(y_ones_idx,
+                                          size=int(sample_size / 2))
+            y_zeros_idx = (y == 0).nonzero()[0]
+            y_zeros_idx = np.random.choice(y_zeros_idx,
+                                           size=int(sample_size / 2))
+
+            y_index = np.concatenate([y_zeros_idx, y_ones_idx], axis=0)
+            np.random.shuffle(y_index)
+            return y_index
+
+        sample_idx = balance_sample(sample_size, data[:, 0])
+        data = data[sample_idx]
+        # '''
+
+        train_num = int(self.tr_frac * len(data))
+
+        self.data_dict['train'] = data[:train_num]
+        self.data_dict['test'] = data[train_num:]
+
+    def _read_raw(self, file_path):
+        data = pd.read_csv(file_path)
+        data = data.values
+        return data
+
+    def _check_existence(self):
+        fpath = os.path.join(self.root, self.base_folder, self.raw_file)
+        return osp.exists(fpath)
+
+    def download(self):
+        if self._check_existence():
+            logger.info("Files already exist")
+            return
+        download_and_extract_archive(self.url,
+                                     os.path.join(self.root, self.base_folder),
+                                     filename=self.url.split('/')[-1])
+
+    def _partition_data(self):
+
+        x = self.data_dict['train'][:, 1:]
+        y = self.data_dict['train'][:, 0]
+
+        test_data = {
+            'x': self.data_dict['test'][:, 1:],
+            'y': self.data_dict['test'][:, 0]
+        }
+        test_x = test_data['x']
+        test_y = test_data['y']
+
+        self.data = dict()
+        for i in range(self.num_of_clients + 1):
+            self.data[i] = dict()
+            if i == 0:
+                self.data[0]['train'] = None
+                self.data[0]['test'] = test_data
+            elif i == 1:
+                self.data[1]['train'] = {'x': x[:, :self.feature_partition[0]]}
+                self.data[1]['test'] = {
+                    'x': test_x[:, :self.feature_partition[0]]
+                }
+            else:
+                self.data[i]['train'] = {
+                    'x': x[:,
+                           self.feature_partition[i -
+                                                  2]:self.feature_partition[i -
+                                                                            1]]
+                }
+                self.data[i]['test'] = {
+                    'x': test_x[:, self.feature_partition[i - 2]:self.
+                    feature_partition[i - 1]]
+                }
+            self.data[i]['val'] = None
+
+        self.data[self.num_of_clients]['train']['y'] = y
+        self.data[self.num_of_clients]['test']['y'] = test_y[:]
diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml
new file mode 100644
index 000000000..c320418b7
--- /dev/null
+++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml
@@ -0,0 +1,33 @@
+use_gpu: False
+device: 0
+backend: torch
+federate:
+  mode: standalone
+  client_num: 2
+model:
+  type: lr
+data:
+  root: data/
+  type: abalone
+  batch_size: 4000
+  splits: [0.8, 0.2]
+dataloader:
+  type: raw
+criterion:
+  type: Regression
+trainer:
+  type: none
+train:
+  optimizer:
+    bin_num: 1000
+    lambda_: 0.1
+    gamma: 0
+    num_of_trees: 10
+    max_tree_depth: 3
+xgb_base:
+  use: True
+  use_bin: True
+  dims: [4, 8]
+eval:
+  freq: 5
+  best_res_update_round_wise_key: test_loss
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult.yaml
index 11d8c5196..49f64aa92 100644
--- a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult.yaml
+++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_adult.yaml
@@ -22,12 +22,12 @@ train:
     bin_num: 100
     lambda_: 0.1
     gamma: 0
-    num_of_trees: 5
+    num_of_trees: 10
     max_tree_depth: 3
 xgb_base:
   use: True
   use_bin: True
   dims: [7, 14]
 eval:
-  freq: 5
+  freq: 3
   best_res_update_round_wise_key: test_loss
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_blogfeedback.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_blogfeedback.yaml
new file mode 100644
index 000000000..7a34f92fd
--- /dev/null
+++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_blogfeedback.yaml
@@ -0,0 +1,33 @@
+use_gpu: False
+device: 0
+backend: torch
+federate:
+  mode: standalone
+  client_num: 2
+model:
+  type: lr
+data:
+  root: data/
+  type: blog
+  batch_size: 8000
+  splits: [1.0, 0.0]
+dataloader:
+  type: raw
+criterion:
+  type: Regression
+trainer:
+  type: none
+train:
+  optimizer:
+    bin_num: 1000
+    lambda_: 10
+    gamma: 0
+    num_of_trees: 9
+    max_tree_depth: 3
+xgb_base:
+  use: True
+  use_bin: True
+  dims: [10, 20]
+eval:
+  freq: 3
+  best_res_update_round_wise_key: test_loss
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_givemesomecredit.yaml b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_givemesomecredit.yaml
new file mode 100644
index 000000000..b40838cd1
--- /dev/null
+++ b/federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_givemesomecredit.yaml
@@ -0,0 +1,33 @@
+use_gpu: False
+device: 0
+backend: torch
+federate:
+  mode: standalone
+  client_num: 2
+model:
+  type: lr
+data:
+  root: data/
+  type: credit
+  batch_size: 2000
+  splits: [0.8, 0.2]
+dataloader:
+  type: raw
+criterion:
+  type: CrossEntropyLoss
+trainer:
+  type: none
+train:
+  optimizer:
+    bin_num: 100
+    lambda_: 0.1
+    gamma: 0
+    num_of_trees: 10
+    max_tree_depth: 3
+xgb_base:
+  use: True
+  use_bin: True
+  dims: [5, 10]
+eval:
+  freq: 3
+  best_res_update_round_wise_key: test_loss
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py
index 97f3baff7..66734a5f5 100644
--- a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py
+++ b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py
@@ -11,55 +11,72 @@ class Test_base:
     def __init__(self, obj):
         self.client = obj
 
-        self.client.register_handlers('test_data',
-                                      self.callback_func_for_test_data)
-        self.client.register_handlers('test_value',
-                                      self.callback_func_for_test_value)
         self.client.register_handlers(
             'split_lr_for_test_data',
             self.callback_func_for_split_lr_for_test_data)
         self.client.register_handlers('LR', self.callback_func_for_LR)
 
-    def callback_func_for_test_value(self, message: Message):
-        self.test_y = message.content
-
-    def callback_func_for_test_data(self, message: Message):
-        self.test_x = message.content
-
-        if self.client.own_label:
-            self.test_z = np.zeros(self.test_x.shape[0])
-
-            tree_num = 0
-            self.test_for_root(tree_num)
+    def evaluation(self):
+        loss = self.client.ls.loss(self.client.test_y, self.client.test_result)
+        if self.client.criterion_type == 'CrossEntropyLoss':
+            metric = self.client.ls.metric(self.client.test_y,
+                                           self.client.test_result)
+            metrics = {
+                'test_loss': loss,
+                'test_acc': metric[1],
+                'test_total': len(self.client.test_y)
+            }
+        else:
+            metrics = {
+                'test_loss': loss,
+                'test_total': len(self.client.test_y)
+            }
+        return metrics
 
     def test_for_root(self, tree_num):
         node_num = 0
         self.client.tree_list[tree_num][node_num].indicator = np.ones(
-            self.test_x.shape[0])
+            self.client.test_x.shape[0])
         self.test_for_node(tree_num, node_num)
 
     def test_for_node(self, tree_num, node_num):
         if node_num >= 2**self.client.max_tree_depth - 1:
-            if tree_num + 1 == self.client.num_of_trees:
-                metric = self.client.ls.metric(self.test_y, self.test_z)
-                loss = self.client.ls.loss(self.test_y, self.test_z)
-                metrics = {
-                    'test_loss': loss,
-                    'test_acc': metric[1],
-                    'test_total': len(self.test_y)
-                }
+            if tree_num + 1 < self.client.num_of_trees:
+                # TODO: add feedback during training
+                self.client.state += 1
+                logger.info(
+                    f'----------- Starting a new training round (Round '
+                    f'#{self.client.state}) -------------')
+                # build the next tree
+                self.client.fs.compute_for_root(tree_num + 1)
+            else:
+                metrics = self.evaluation()
                 self.client.comm_manager.send(
                     Message(msg_type='test_result',
                             sender=self.client.ID,
                             state=self.client.state,
                             receiver=self.client.server_id,
-                            content=metrics))
+                            content=(tree_num, metrics)))
 
-        else:
-            self.test_for_root(tree_num + 1)
+                self.client.comm_manager.send(
+                    Message(msg_type='send_feature_importance',
+                            sender=self.client.ID,
+                            state=self.client.state,
+                            receiver=[
+                                each for each in list(
+                                    self.client.comm_manager.neighbors.keys())
+                                if each != self.client.server_id
+                            ],
+                            content=None))
+                self.client.comm_manager.send(
+                    Message(msg_type='feature_importance',
+                            sender=self.client.ID,
+                            state=self.client.state,
+                            receiver=self.client.server_id,
+                            content=self.client.feature_importance))
 
         elif self.client.tree_list[tree_num][node_num].weight:
-            self.test_z += self.client.tree_list[tree_num][
+            self.client.test_result += self.client.tree_list[tree_num][
                 node_num].indicator * self.client.tree_list[tree_num][
                     node_num].weight
             self.test_for_node(tree_num, node_num + 1)
@@ -78,7 +95,7 @@ def callback_func_for_split_lr_for_test_data(self, message: Message):
         tree_num, node_num = message.content
         feature_idx = self.client.tree_list[tree_num][node_num].feature_idx
         feature_value = self.client.tree_list[tree_num][node_num].feature_value
-        L, R = self.client.split_for_lr(self.test_x[:, feature_idx],
+        L, R = self.client.split_for_lr(self.client.test_x[:, feature_idx],
                                         feature_value)
         self.client.comm_manager.send(
             Message(msg_type='LR',
diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
index f22475a95..a62d57774 100644
--- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
+++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
@@ -48,6 +48,13 @@ def __init__(self,
 
         self.data = data
         self.own_label = ('y' in self.data['train'])
+
+        self.test_x = self.data['test']['x']
+        if self.own_label:
+            self.test_y = self.data['test']['y']
+
+        self.test_result = np.zeros(self.test_x.shape[0])
+
         self.y_hat = None
         self.y = None
         self.num_of_parties = config.federate.client_num
@@ -71,8 +78,6 @@ def __init__(self,
         self.feature_order = [0] * self.my_num_of_feature
         self.feature_importance = [0] * self.my_num_of_feature
 
-        # self.ss = AdditiveSecretSharing(shared_party_num=self.num_of_parties)
-
         self.ns = Node_split()
         # self.fs = Feature_sort()
         # the following two lines are the two algos, where
         # the first one corresponds to sending the whole feature order
@@ -84,9 +89,10 @@ def __init__(self,
 
         self.ts = Test_base(self)
 
-        if config.criterion.type == 'CrossEntropyLoss':
+        self.criterion_type = config.criterion.type
+        if self.criterion_type == 'CrossEntropyLoss':
             self.ls = TwoClassificationloss()
-        elif config.criterion.type == 'Regression':
+        elif self.criterion_type == 'Regression':
             self.ls = Regression_by_mseloss()
 
         self.register_handlers('model_para', self.callback_func_for_model_para)
@@ -195,40 +201,8 @@ def compute_weight(self, tree_num, node_num):
             else:
                 self.y_hat += self.z
             self.z = 0
-            metric = self.ls.metric(self.y, self.y_hat)
-
-            if tree_num + 1 == self.num_of_trees:
-                self.comm_manager.send(
-                    Message(msg_type='test',
-                            sender=self.ID,
-                            state=self.state,
-                            receiver=self.server_id,
-                            content=None))
-                self.comm_manager.send(
-                    Message(msg_type='send_feature_importance',
-                            sender=self.ID,
-                            state=self.state,
-                            receiver=[
-                                each for each in list(
-                                    self.comm_manager.neighbors.keys())
-                                if each != self.server_id
-                            ],
-                            content=None))
-                self.comm_manager.send(
-                    Message(msg_type='feature_importance',
-                            sender=self.ID,
-                            state=self.state,
-                            receiver=self.server_id,
-                            content=self.feature_importance))
-
-            else:
-                self.state += 1
-                logger.info(
-                    f'----------- Starting a new training round (Round '
-                    f'#{self.state}) -------------')
-                tree_num += 1
-                # to build the next tree
-                self.fs.compute_for_root(tree_num)
+
+            self.ts.test_for_root(tree_num)
         else:
             if self.tree_list[tree_num][node_num].weight:
                 self.z += self.tree_list[tree_num][
diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
index e6ab0c407..4b99302de 100644
--- a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
+++ b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
@@ -47,7 +47,6 @@ def __init__(self,
         ]
         self.feature_importance_dict = dict()
 
-        self.register_handlers('test', self.callback_func_for_test)
         self.register_handlers('test_result',
                                self.callback_func_for_test_result)
         self.register_handlers('feature_importance',
@@ -67,6 +66,7 @@ def broadcast_model_para(self):
                     content=(self.lambda_, self.gamma, self.num_of_trees,
                              self.max_tree_depth)))
 
+    # TODO: merge the following two callback funcs
     def callback_func_for_feature_importance(self, message: Message):
         feature_importance = message.content
         self.feature_importance_dict[message.sender] = feature_importance
@@ -74,38 +74,19 @@ def callback_func_for_feature_importance(self, message: Message):
             self.feature_importance_dict = dict(
                 sorted(self.feature_importance_dict.items(), key=lambda x: x[0]))
-
-    def callback_func_for_test(self, message: Message):
-        test_x = self.data['test']['x']
-        test_y = self.data['test']['y']
-        for i in range(self.num_of_parties):
-            test_data = test_x[:,
-                               self.feature_list[i]:self.feature_list[i + 1]]
-            self.comm_manager.send(
-                Message(msg_type='test_data',
-                        sender=self.ID,
-                        receiver=i + 1,
-                        state=self.state,
-                        content=test_data))
-        self.comm_manager.send(
-            Message(msg_type='test_value',
-                    sender=self.ID,
-                    receiver=self.num_of_parties,
-                    state=self.state,
-                    content=test_y))
+            self._monitor.update_best_result(self.best_results,
+                                             self.metrics,
+                                             results_type='server_global_eval')
+            self._monitor.add_items_to_best_result(
+                self.best_results,
+                self.feature_importance_dict,
+                results_type='feature_importance')
+            formatted_logs = self._monitor.format_eval_res(
+                self.metrics,
+                rnd=self.tree_num,
+                role='Server #',
+                forms=self._cfg.eval.report)
+            logger.info(formatted_logs)
 
     def callback_func_for_test_result(self, message: Message):
-        metrics = message.content
-        self._monitor.update_best_result(self.best_results,
-                                         metrics,
-                                         results_type='server_global_eval')
-        self._monitor.add_items_to_best_result(
-            self.best_results,
-            self.feature_importance_dict,
-            results_type='feature_importance')
-        formatted_logs = self._monitor.format_eval_res(
-            metrics,
-            rnd=self.state,
-            role='Server #',
-            forms=self._cfg.eval.report)
-        logger.info(formatted_logs)
+        self.tree_num, self.metrics = message.content
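
Note on `feature_partition` (populated from `xgb_base.dims` in the YAML files):
the entries are cumulative column boundaries, not per-client feature counts.
For example, `dims: [4, 8]` assigns columns 0:4 to client 1 and columns 4:8 to
client 2, with the label held by the last client only. A minimal standalone
NumPy sketch of the column-wise split performed by `_partition_data` (toy
variable names, illustrative only, not part of this patch):

    import numpy as np

    # dims for the Abalone baseline: client 1 gets columns 0:4,
    # client 2 gets columns 4:8 (and, as the last client, the label)
    feature_partition = [4, 8]
    x = np.random.rand(6, feature_partition[-1])  # toy feature matrix

    parts, start = {}, 0
    for client_id, end in enumerate(feature_partition, start=1):
        parts[client_id] = x[:, start:end]  # vertical (column-wise) split
        start = end

    assert parts[1].shape == (6, 4) and parts[2].shape == (6, 4)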
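
The new baselines should be runnable through the usual entry point (assuming
the repository root as working directory), e.g.:

    python federatedscope/main.py --cfg \
        federatedscope/vertical_fl/xgb_base/baseline/xgb_base_on_abalone.yaml

The raw archives are fetched on first use via `download_and_extract_archive`
when `download=True`, so no manual download step should be needed.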