diff --git a/examples/community/vfl_attacks/README.md b/examples/community/vfl_attacks/README.md new file mode 100644 index 0000000..4c38234 --- /dev/null +++ b/examples/community/vfl_attacks/README.md @@ -0,0 +1,93 @@ +# 纵向联邦学习攻击 + +本项目基于MindSpore框架实现了纵向联邦学习模型的一种后门攻击方法和一种标签推理攻击方法。 + +## 原型论文 + +T. Zou, Y. Liu, Y. Kang, W. Liu, Y. He, Z. Yi, Q. Yang, and Y.-Q. Zhang, “Defending batch-level label inference and replacement attacks in vertical federated learning,” IEEE Transactions on Big Data, pp. 1–12, 2022. [PDF](https://www.computer.org/csdl/journal/bd/5555/01/09833321/1F8uKhxrvNe) + +Fu C, Zhang X, Ji S, et al. Label inference attacks against vertical federated learning[C]//31st USENIX security symposium (USENIX Security 22). 2022: 1397-1414. [PDF](https://www.usenix.org/conference/usenixsecurity22/presentation/fu-chong) + +## 环境要求 + +MindSpore >= 1.9 + +## 脚本说明 + +```markdown +│ README.md +│ the_example.py // 应用示例 +│ +├─examples //示例 +│ ├─common +│ │ │ constants.py //用户定义常量 +│ │ +│ ├─datasets +│ │ │ cifar_dataset.py //用户加载数据集 +│ │ │ functions.py +│ │ +│ └─model +│ │ init_active_model.py //用户加载顶层模型 +│ │ init_passive_model.py //用户加载底层模型 +│ │ resnet.py //用户定义模型结构 +│ │ resnet_cifar.py +│ │ top_model_fcn.py +│ │ vgg.py +│ │ vgg_cifar.py +│ +├─utils //实现VFL功能和两个算法 +│ ├─config +│ │ │ args_process.py //审查并处理用户传入的参数 +│ │ │ config.yaml //默认参数配置文件 +│ │ +│ ├─datasets //定义VFL数据集加载方式 +│ │ │ common.py +│ │ │ image_dataset.py +│ │ +│ ├─methods +│ │ ├─direct_attack +│ │ │ │ direct_attack_passive_party.py //定义直接标签推理攻击的攻击者对象 +│ │ │ │ direct_attack_vfl.py //定义直接标签推理攻击的VFL对象 +│ │ │ +│ │ └─g_r +│ │ │ g_r_passive_party.py //定义梯度替换后门攻击的攻击者对象 +│ │ +│ ├─model +│ │ │ base_model.py //VFL中模型的基本类 +│ │ +│ ├─party +│ │ │ active_party.py //主动方对象 +│ │ │ passive_party.py //被动方对象 +│ │ +│ └─vfl +│ │ init_vfl.py //初始化各参与方 +│ │ vfl.py //定义VFL对象,包括各类过程函数 +``` + +## 引入相关包 + +```Python +from utils.vfl.init_vfl import Init_Vfl +from utils.vfl.vfl import VFL +from
utils.methods.direct_attack.direct_attack_vfl import DirectVFL +from utils.config.args_process import argsments_function +``` + +## Init_Vfl介绍 + +该模块负责垂直联邦学习(VFL)中参与者的初始化,包括参与者的模型、参数和类。对于主动参与方,定义对象为VFLActiveModel,对于正常被动参与方,定义对象为VFLPassiveModel,对于梯度替换后门攻击,定义攻击者对象为GRPassiveModel,对于直接标签推理攻击,定义对象为DirectAttackPassiveModel。 + +## VFL介绍 + +该模块定义了VFL中各种过程函数,包括训练、预测、更新等。 + +## DirectVFL 介绍 + +该模块实现了直接标签推理攻击,在VFL类的基础上定义了直接标签推理攻击中的过程函数。 + +## argsments_function 介绍 + +该函数接受并审查用户的参数,并封装为utils中各个类支持的格式。其中,梯度替换攻击用于分割VFL场景,直接标签推理攻击只适用于不分割VFL场景。 + +## 扩展 +本项目当前支持CIFAR-10、BHI数据集,目前examples/datasets文件夹中给出了CIFAR-10数据集加载代码,BHI数据集参考该代码进行扩展即可。如需自定义模型结构或数据集加载方式,请参考并修改examples文件夹中的对应文件内容。 \ No newline at end of file diff --git a/examples/community/vfl_attacks/__init__.py b/examples/community/vfl_attacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/community/vfl_attacks/examples/__init__.py b/examples/community/vfl_attacks/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/community/vfl_attacks/examples/common/constants.py b/examples/community/vfl_attacks/examples/common/constants.py new file mode 100644 index 0000000..334cb94 --- /dev/null +++ b/examples/community/vfl_attacks/examples/common/constants.py @@ -0,0 +1,21 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+This module defines constant variables.
+"""
+
+checkpoint_path = './output_logs'
+output_path = './output_logs/'
+data_path = '../data/'
\ No newline at end of file
diff --git a/examples/community/vfl_attacks/examples/datasets/__init__.py b/examples/community/vfl_attacks/examples/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/community/vfl_attacks/examples/datasets/cifar_dataset.py b/examples/community/vfl_attacks/examples/datasets/cifar_dataset.py
new file mode 100644
index 0000000..1567889
--- /dev/null
+++ b/examples/community/vfl_attacks/examples/datasets/cifar_dataset.py
@@ -0,0 +1,186 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Load the dataset and construct the dataloader.
+
+This module provides functions to create data loaders for the CIFAR-10 and CIFAR-100 datasets,
+including train data loader, test data loader, and backdoor test data loader.
+"""
+import os
+import pickle
+import numpy as np
+import mindspore as ms
+from mindspore.dataset import vision
+from utils.datasets.common import generate_dataloader
+from examples.datasets.functions import get_random_indices, get_target_indices
+from examples.common.constants import data_path
+
+# Transform for CIFAR train dataset.
+train_transform = ms.dataset.transforms.Compose([
+    vision.ToTensor()
+])
+
+# Transform for CIFAR test dataset.
+test_transform = ms.dataset.transforms.Compose([
+    vision.ToTensor()
+])
+
+def _get_labeled_data_with_2_party(data_dir, dataset, dtype="train", num_samples=None):
+    """
+    Read CIFAR data and labels from local pickle files.
+
+    Args:
+        data_dir (str): Directory that contains the extracted dataset folder.
+        dataset (str): Dataset name. 'cifar10' reads CIFAR-10; any other value reads CIFAR-100.
+        dtype (str): Split to read, 'train' or 'test'. For CIFAR-10 any value other than 'train'
+            selects the test batch; for CIFAR-100 the value is used as the file name.
+        num_samples (int, optional): Number of samples to draw at random without replacement.
+            None keeps the whole split.
+
+    Returns:
+        tuple: A tuple containing the data X with shape (N, 3, 32, 32) and the labels Y.
+    """
+    # NOTE: pickle.load must only be used on trusted dataset files.
+    if dataset == 'cifar10':
+        batch_dir = os.path.join(data_dir, 'cifar-10-batches-py')
+        if dtype == 'train':
+            file_names = [f'data_batch_{i}' for i in range(1, 6)]
+        else:
+            file_names = ['test_batch']
+        all_data = []
+        targets = []
+        for file_name in file_names:
+            with open(os.path.join(batch_dir, file_name), 'rb') as f:
+                entry = pickle.load(f, encoding='latin1')
+            all_data.append(entry['data'])
+            # Fall back to 'fine_labels' for batches that carry no 'labels' entry.
+            targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
+        all_data = np.vstack(all_data)
+    else:
+        with open(os.path.join(data_dir, 'cifar-100-python', dtype), 'rb') as f:
+            datadict = pickle.load(f, encoding='latin1')
+        all_data = datadict['data']
+        targets = datadict['fine_labels']
+    all_data = all_data.reshape(-1, 3, 32, 32)
+    targets = np.array(targets)
+
+    # The optional random sub-sampling is identical for both datasets.
+    if num_samples is None:
+        return all_data, targets
+    indices = get_random_indices(num_samples, len(all_data))
+    return all_data[indices], targets[indices]
+
+
+def _load_two_party_data(data_dir,
args): + """ + Get data from a local dataset, supporting only two parties. + + Args: + data_dir (str): Path of the local dataset. + args (dict): Configuration. + + Returns: + tuple: A tuple containing the following data: + X_train: Normal train features. + y_train: Normal train labels. + X_test: Normal test features. + y_test: Normal test labels. + backdoor_X_test: Backdoor test features. + backdoor_y_test: Backdoor test labels. + backdoor_indices_train: Indices of backdoor samples in the normal train dataset. + backdoor_target_indices: Indices of backdoor labels in the normal train dataset. + """ + print("# load_two_party_data") + n_train = args['target_train_size'] + n_test = args['target_test_size'] + if n_train == -1: + n_train = None + if n_test == -1: + n_test = None + + x_train, y_train = _get_labeled_data_with_2_party(data_dir=data_dir, + dataset=args['dataset'], + dtype='train', + num_samples=n_train) + + x_test, y_test = _get_labeled_data_with_2_party(data_dir=data_dir, + dataset=args['dataset'], + dtype='test', + num_samples=n_test) + + # Randomly select samples of other classes from normal train dataset as backdoor samples. + train_indices = np.where(y_train != args['backdoor_label'])[0] + backdoor_indices_train = np.random.choice(train_indices, args['backdoor_train_size'], replace=False) + + # Randomly select samples of other classes from normal test dataset to generate backdoor test dataset. + test_indices = np.where(y_test != args['backdoor_label'])[0] + backdoor_indices_test = np.random.choice(test_indices, args['backdoor_test_size'], replace=False) + backdoor_x_test, backdoor_y_test = x_test[backdoor_indices_test], \ + y_test[backdoor_indices_test] + backdoor_y_test = np.full_like(backdoor_y_test, args['backdoor_label']) + + # Randomly select samples of backdoor label in normal train dataset, for gradient-replacement. 
+ backdoor_target_indices = get_target_indices(y_train, args['backdoor_label'], args['backdoor_train_size']) + + print(f"y_train.shape: {y_train.shape}") + print(f"y_test.shape: {y_test.shape}") + print(f"backdoor_y_test.shape: {backdoor_y_test.shape}") + + return x_train, y_train, x_test, y_test, backdoor_x_test, backdoor_y_test, \ + backdoor_indices_train, backdoor_target_indices + + +def get_cifar_dataloader(args): + """ + Generate loaders for the CIFAR dataset, supporting CIFAR-10 and CIFAR-100. + + Args: + args (dict): Configuration. + + Returns: + tuple: A tuple containing the following data loaders: + train_dl: Loader for the normal train dataset. + test_dl: Loader for the normal test dataset. + backdoor_test_dl: Loader for the backdoor test dataset, containing only backdoor samples, + used for ASR evaluation. + backdoor_indices: Indices of backdoor samples in the normal train dataset. + backdoor_target_indices: Indices of backdoor labels in the normal train dataset, + used by Gradient-Replacement. + """ + result = _load_two_party_data(data_path, args) + x_train, y_train, x_test, y_test, backdoor_x_test, backdoor_y_test, \ + backdoor_indices, backdoor_target_indices = result + + batch_size = args['target_batch_size'] + # Get loader of normal train dataset, used by normal training. + train_dl = generate_dataloader((x_train, y_train), batch_size, train_transform, shuffle=True, half=args['half']) + # GFet loader of normal test dataset, used to evaluate main task accuracy. + test_dl = generate_dataloader((x_test, y_test), batch_size, test_transform, shuffle=False, half=args['half']) + + backdoor_test_dl = None + if args['backdoor'] != 'no': + # Get loader of backdoor test dataset, used to evaluate backdoor task accuracy. 
+        backdoor_test_dl = generate_dataloader((backdoor_x_test, backdoor_y_test), batch_size, test_transform,
+                                               shuffle=False,
+                                               backdoor_indices=np.arange(args['backdoor_test_size']),
+                                               trigger=args['trigger'], trigger_add=args['trigger_add'], half=args['half'])
+
+    if args['backdoor'] == 'g_r':
+        # Get loader of train dataset used by Gradient-Replacement, containing backdoor features and normal labels.
+        train_dl = generate_dataloader((x_train, y_train), batch_size, train_transform,
+                                       shuffle=True,
+                                       backdoor_indices=backdoor_indices, half=args['half'])
+
+    return train_dl, test_dl, backdoor_test_dl, backdoor_indices, backdoor_target_indices
\ No newline at end of file
diff --git a/examples/community/vfl_attacks/examples/datasets/functions.py b/examples/community/vfl_attacks/examples/datasets/functions.py
new file mode 100644
index 0000000..59c38be
--- /dev/null
+++ b/examples/community/vfl_attacks/examples/datasets/functions.py
@@ -0,0 +1,58 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This module provides functions to select samples from a dataset.
+
+Functions:
+    get_target_indices chooses the samples of the specified category.
+    get_random_indices randomly selects samples.
+"""
+import numpy as np
+
+def get_target_indices(labels, target_label, size, backdoor_indices=None):
+    """
+    Get indices with a specified size of the target label.
+
+    Args:
+        labels (ndarray): Array of labels in the dataset.
+        target_label (int): The target label to filter.
+        size (int): The number of indices to return.
+        backdoor_indices (ndarray, optional): Indices to exclude from the result. None excludes nothing.
+
+    Returns:
+        ndarray: An array of at most `size` shuffled indices whose label equals the target label.
+    """
+    indices = np.where(labels == target_label)[0]
+    # Only subtract when there is something to exclude, instead of handing None to np.setdiff1d.
+    if backdoor_indices is not None:
+        indices = np.setdiff1d(indices, backdoor_indices)
+    np.random.shuffle(indices)
+    return indices[:size]
+
+
+def get_random_indices(target_length, all_length):
+    """
+    Generate random indices.
+
+    Args:
+        target_length (int): The length of the target indices to generate.
+        all_length (int): The total length of all indices available.
+
+    Returns:
+        ndarray: An array of random indices.
+    """
+    all_indices = np.arange(all_length)
+    indices = np.random.choice(all_indices, target_length, replace=False)
+    return indices
diff --git a/examples/community/vfl_attacks/the_example.py b/examples/community/vfl_attacks/the_example.py
new file mode 100644
index 0000000..e7ebd6b
--- /dev/null
+++ b/examples/community/vfl_attacks/the_example.py
@@ -0,0 +1,73 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))+'/utils/') +import mindspore +from utils.vfl.init_vfl import Init_Vfl +from utils.vfl.vfl import VFL +from utils.methods.direct_attack.direct_attack_vfl import DirectVFL +from utils.config.args_process import argsments_function +from examples.model.init_active_model import init_top_model +from examples.model.init_passive_model import init_bottom_model +from examples.datasets.cifar_dataset import get_cifar_dataloader +mindspore.set_context(mode=mindspore.GRAPH_MODE) + +if __name__ == '__main__': + gpu = 0 + if gpu: + mindspore.set_context(device_target="GPU") + epochs = 100 + batch_size = 32 + lr = 0.005 + top_model_trainable = 0 + n_party = 2 + adversary = 1 + num_classes = 10 + topk = 1 + target_train_size = -1 + target_test_size = -1 + backdoor_test_size = 2000 + # Support g_r and direct_attack. + alg = 'direct_attack' + backdoor_label = 6 + poison_num = 500 + # Process the arguments. + args = argsments_function(epochs, batch_size, lr, top_model_trainable, n_party, adversary, gpu, + target_train_size, target_test_size, backdoor_test_size, backdoor_label, + alg, num_classes, poison_num, topk) + local_args = args.copy() + local_args['dataset'] = 'cifar10' + local_args['half'] = 16 + local_args['model_type'] = 'Resnet' + # Define dataset loader. + train_dl, test_dl, backdoor_test_dl, backdoor_indices, backdoor_target_indices = get_cifar_dataloader(local_args) + # Define models. + bottoms = [init_bottom_model('active', local_args)] + for i in range(0, args['n_passive_party']): + passive_party_model = init_bottom_model('passive', local_args) + bottoms.append(passive_party_model) + top = init_top_model(local_args) + # Initialize vfl framework. 
+    if alg == 'g_r':
+        init_vfl = Init_Vfl(args)
+        init_vfl.get_vfl(bottoms, top, train_dl, backdoor_target_indices, backdoor_indices)
+        VFL_framework = VFL(train_dl, test_dl, init_vfl, backdoor_test_dl)
+        VFL_framework.train()
+    elif alg == 'direct_attack':
+        init_vfl = Init_Vfl(args)
+        init_vfl.get_vfl(bottoms, top, train_dl)
+        VFL_framework = DirectVFL(train_dl, test_dl, init_vfl)
+        VFL_framework.train()
\ No newline at end of file
diff --git a/examples/community/vfl_attacks/utils/config/args_process.py b/examples/community/vfl_attacks/utils/config/args_process.py
new file mode 100644
index 0000000..d93798e
--- /dev/null
+++ b/examples/community/vfl_attacks/utils/config/args_process.py
@@ -0,0 +1,109 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This module receives and processes the configuration to ensure it conforms to the VFL.
+"""
+import os
+import warnings
+
+import yaml
+
+__all__ = ["argsments_function"]
+
+def argsments_function(epochs=100, batch_size=64, lr=0.01, top_model_trainable=1, n_party=2, adversary=1,
+                       gpu=0, target_train_size=-1, target_test_size=-1, backdoor_test_size=2000, backdoor_label=0,
+                       alg='no', num_classes=10, poison_num=10, topk=1):
+    """
+    Validate the arguments and merge them into the defaults read from config.yaml.
+
+    Args:
+        epochs (int): Training epochs.
+        batch_size (int): The size of each batch.
+        lr (float): Learning rate.
+        top_model_trainable (int): VFL with splitting or VFL without splitting.
+        n_party (int): The number of parties.
+        adversary (int): The ID of the attacker.
+        gpu (int): Use GPU or not.
+        target_train_size (int): The size of the train dataset.
+        target_test_size (int): The size of the test dataset.
+        backdoor_test_size (int): The size of the backdoor test dataset.
+        backdoor_label (int): Backdoor target label.
+        alg (str): The name of the algorithm.
+        num_classes (int): The number of classes.
+        poison_num (int): The number of poisoned samples.
+        topk (int): The top-k metric.
+
+    Returns:
+        dict: The processed arguments, or None (after a warning) if the arguments are not supported.
+    """
+    unsupported = (
+        (alg == 'direct_attack' and top_model_trainable,
+         "direct_attack only supports VFL without splitting (top_model_trainable=0)"),
+        (alg == 'g_r' and not top_model_trainable,
+         "g_r only supports VFL with splitting (top_model_trainable=1)"),
+        (n_party != 2, "only n_party=2 is supported"),
+        (adversary != 1, "only adversary=1 is supported"),
+        (backdoor_label >= num_classes, "backdoor_label must be smaller than num_classes"),
+    )
+    for failed, reason in unsupported:
+        if failed:
+            # Still return None as before, but say why instead of failing silently.
+            warnings.warn(f"argsments_function: {reason}; returning None.", stacklevel=2)
+            return None
+
+    # safe_load uses the safe loader explicitly, so the deprecated yaml.warnings() switch is not
+    # needed; the context manager closes the file even when parsing fails.
+    config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
+    with open(config_file, 'r', encoding='utf-8') as f:
+        args = yaml.safe_load(f)
+
+    args['num_classes'] = num_classes
+    args['target_epochs'] = epochs
+    args['passive_bottom_lr'] = lr
+    args['active_bottom_lr'] = lr
+    args['active_top_lr'] = lr
+    args['target_batch_size'] = batch_size
+    args['passive_bottom_gamma'] = 0.1
+    args['active_bottom_gamma'] = 0.1
+    args['active_top_gamma'] = 0.1
+    args['cuda'] = gpu
+    args['active_top_trainable'] = bool(top_model_trainable)
+    args['n_passive_party'] = n_party - 1
+    args['adversary'] = adversary
+    args['target_train_size'] = target_train_size
+    args['target_test_size'] = target_test_size
+    args['backdoor_test_size'] = backdoor_test_size
+    args['backdoor_label'] = backdoor_label
+    args['topk'] = topk
+    args['aggregate'] = 'Concate'
+    if alg == 'g_r':
+        args['attack'] = True
+        args['backdoor'] = alg
+        args['label_inference_attack'] = 'no'
+        args['trigger'] = 'pixel'
+        args['trigger_add'] = False
+    elif alg == 'direct_attack':
+        args['attack'] = True
+        args['backdoor'] = 'no'
+        args['label_inference_attack'] = alg
+    # Neutral values for keys that neither config.yaml nor the branches above provided (e.g. with
+    # the default alg='no'), so that later lookups such as args['backdoor'] do not raise KeyError.
+    args.setdefault('attack', False)
+    args.setdefault('backdoor', 'no')
+    args.setdefault('label_inference_attack', 'no')
+    args.setdefault('trigger', 'pixel')
+    args['amplify_ratio'] = 1
+    args['backdoor_train_size'] = poison_num
+    return args
diff --git a/examples/community/vfl_attacks/utils/config/config.yaml b/examples/community/vfl_attacks/utils/config/config.yaml
new file mode 100644
index 0000000..83513d8
--- /dev/null
+++ b/examples/community/vfl_attacks/utils/config/config.yaml
@@ -0,0 +1,60 @@
+# global configuration
+half: 16
+cuda: 1
+log: vfl
+debug: False
+
+# trigger
+trigger_add: False
+# data configuration
+target_train_size: -1
+target_test_size: -1
+backdoor_train_size: 500
+backdoor_test_size: 2000
+train_label_size: 40
+#target_train_size: 2000
+#target_test_size: 2000
+#backdoor_train_size: 50
+#backdoor_test_size: 200
+#train_label_size: 40
+
+# save configuration
+save_model: 0
+save_data: 0
+
+# load configuration
+load_target: 0
+load_data: 0
+load_model: 0
+load_time: 0
+
+# global model configuration
+n_passive_party: 1
+target_batch_size: 32
+target_epochs: 100
+# passive party configuration
+passive_bottom_model: resnet
+passive_bottom_gamma: 0.1
+passive_bottom_wd: 0.0005
+passive_bottom_momentum: 0.9
+passive_bottom_lr: 0.1
+passive_bottom_stone: [50,85,100]
+# active party configuration
+# active bottom model configuration
+active_bottom_model: resnet
+active_bottom_gamma: 0.1
+active_bottom_wd: 0.0005
+active_bottom_momentum: 0.9
+active_bottom_lr: 0.1
+active_bottom_stone: [50,85,100]
+# active top model configuration
+active_top_trainable: 1
+active_top_model: fcn
+active_top_gamma: 0.1
+active_top_wd: 0.0005
+active_top_momentum: 0.9
+active_top_lr: 0.1
+active_top_stone: [50,85,100]
+
+# backdoor attack global configuration
+backdoor_label: 6
\ No newline at end of file
diff --git a/examples/community/vfl_attacks/utils/datasets/common.py b/examples/community/vfl_attacks/utils/datasets/common.py
new file mode 100644
index 0000000..b010ac9
--- /dev/null
+++ b/examples/community/vfl_attacks/utils/datasets/common.py
@@ -0,0
+1,72 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Define data loaders for Vertical Federated Learning (VFL). + +This module provides functions to create data loaders specifically designed for VFL scenarios. +""" +import mindspore as ms +from datasets.image_dataset import ImageDataset + +__all__ = ["generate_dataloader"] + +def generate_dataloader(data_list, batch_size, transform=None, shuffle=True, backdoor_indices=None, + trigger=None, trigger_add=None, source_indices=None, half=16): + """ + Generate loader from dataset. + + Args: + data_list (tuple): Contains X and Y. + batch_size (int): Batch size of the loader. + transform: Transform of the loader. + shuffle (bool): Whether to shuffle the loader. + backdoor_indices (ndarray): Indices of backdoor samples in normal dataset. + Adds trigger when loading data if index is in backdoor_indices. + trigger (str): Controls whether to add a pixel trigger when loading the data loader. + half (int): An integer specifying the size of a data subset + Returns: + DataLoader: The generated loader. + """ + x, y = data_list + + ImageDatasetWithIndices = _image_dataset_with_indices(ImageDataset) + # Split x into halves for parties when loading data, only support two parties. 
+ ds = ImageDatasetWithIndices(x, ms.tensor(y, ms.int32), + transform=transform, + backdoor_indices=backdoor_indices, + half=half, trigger=trigger, trigger_add=trigger_add, source_indices=source_indices) + dl = ms.dataset.GeneratorDataset(source=ds, shuffle=shuffle, column_names=['image','target','old_imgb', 'indice']) + dl = dl.batch(batch_size, drop_remainder=False) + return dl + + +def _image_dataset_with_indices(cls): + """ + Build dataset class that can output x, y, and index when loading data based on cls, used for image dataset. + + Args: + cls: The original dataset class. + + Returns: + type: New dataset class. + """ + def __getitem__(self, index): + X_data, target, original = cls.__getitem__(self, index) + return X_data, target, original, index + + type_of_cls = type(cls.__name__, (cls,), { + '__getitem__': __getitem__, + }) + return type_of_cls diff --git a/examples/community/vfl_attacks/utils/datasets/image_dataset.py b/examples/community/vfl_attacks/utils/datasets/image_dataset.py new file mode 100644 index 0000000..f817554 --- /dev/null +++ b/examples/community/vfl_attacks/utils/datasets/image_dataset.py @@ -0,0 +1,177 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Define image dataset, only support two parties, used for CIFAR and CINIC +""" +from typing import Any, Callable, Optional +import numpy as np +import mindspore as ms + +class ImageDataset(object): + """ + The vfl dataset, support image 3,32,32 and two parties. + """ + def __init__( + self, + X, y, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + backdoor_indices=None, + half=None, + trigger=None, + trigger_add=False, + source_indices=None + ) -> None: + """ + Construction of dataset for scenarios with two participants. + + Args: + X (ndarray): The data. + y (ndarray): The labels. + transform (callable, optional): Transform for image. + target_transform (callable, optional): Transformer for labels. + backdoor_indices (ndarray): The indices of poison samples. + half (int, optional): Default is 16. + trigger (str, optional): Default is 'pixel'. Add trigger on poison samples; otherwise, + do not add trigger on poison samples. + trigger_add (bool, optional): Whether the trigger is additive or a mask trigger. + source_indices (list, optional): The indices of samples used to replace poisoned samples. 
+ """ + self.transform = transform + self.target_transform = target_transform + + self.data: Any = [] + self.targets = [] + + self.data = X + self.targets = y + + self.backdoor_indices = backdoor_indices + self.source_indices = source_indices + + if backdoor_indices is not None and source_indices is not None: + self.indice_map = dict(zip(backdoor_indices, source_indices)) + else: + self.indice_map = None + + self.half = half + self.trigger = trigger + if self.trigger is None: + self.trigger = 'pixel' + self.trigger_add = trigger_add + channels, height, width = self.data[0].shape + if channels in {1, 3, 4}: + self.pixel_pattern = np.full((channels, height, half), 0) + else: + self.pixel_pattern = np.full((width, height, half), 0) + + pattern_mask: ms.Tensor = ms.tensor([ + [1., 0., 1.], + [-10., 1., -10.], + [-10., -10., 0.], + [-10., 1., -10.], + [1., 0., 1.] + ], dtype=ms.float32) + + pattern_mask = pattern_mask.unsqueeze(0) + self.pattern_mask = pattern_mask.tile((3, 1, 1)) + self.pattern_mask = self.pattern_mask.asnumpy() + x_top = 3 + y_top = 3 + x_bot = x_top + self.pattern_mask.shape[1] + y_bot = y_top + self.pattern_mask.shape[2] + self.location = [x_top, x_bot, y_top, y_bot] + + def __getitem__(self, index): + """ + Get data. + + Args: + index (int): The index of the sample. + + Returns: + tuple: A tuple containing: + imgs (ndarray): Clean data or poisoned data. + target (int): The label. + old_img (ndarray): The original data without poisoning. 
+ """ + index = int(index) + img, target = self.data[index], self.targets[index] + if img.shape[0] in {1, 3, 4}: + img = img.transpose(1, 2, 0) + if self.transform is not None: + img = self.transform(img) + img_a, img_b = img[:, :, :self.half], img[:, :, self.half:] + + if self.target_transform is not None: + target = self.target_transform(target) + + old_imgb = img_b + + if self.trigger == 'pixel': + if self.indice_map is not None and index in self.indice_map.keys(): + source_indice = self.indice_map[index] + source_img = self.data[source_indice] + source_img = source_img.transpose(1, 2, 0) + if self.transform is not None: + source_img = self.transform(source_img) + img_b = source_img[:, :, self.half:] + + # add trigger if index is in backdoor indices + if self.trigger == 'pixel': + if self.backdoor_indices is not None and index in self.backdoor_indices: + if self.trigger_add: + img_b = img_b + self.pixel_pattern + else: + img_b = _add_pixel_pattern_backdoor(img_b, self.pattern_mask, self.location) + imgs = (img_a, img_b) + return imgs, target, old_imgb + + def __len__(self) -> int: + """ + Get the length of the dataset. + """ + return len(self.data) + + +def _add_pixel_pattern_backdoor(inputs, pattern_tensor, location): + """ + Add pixel pattern trigger to image. + + Args: + inputs (Tensor): Normal images. + pattern_tensor (ndarray): The additive trigger. + location (list or tuple): The area to put the trigger. + + Returns: + Tensor: Images with the trigger. 
+ """ + mask_value = -10 + + input_shape = inputs.shape + full_image = np.full(input_shape, mask_value, dtype=inputs.dtype) + + x_top = location[0] + x_bot = location[1] + y_top = location[2] + y_bot = location[3] + full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor + + mask = 1 * (full_image != mask_value) + pattern = full_image + + inputs = (1 - mask) * inputs + mask * pattern + inputs = inputs.astype(pattern_tensor.dtype) + return inputs diff --git a/examples/community/vfl_attacks/utils/methods/direct_attack/direct_attack_passive_party.py b/examples/community/vfl_attacks/utils/methods/direct_attack/direct_attack_passive_party.py new file mode 100644 index 0000000..cfb367c --- /dev/null +++ b/examples/community/vfl_attacks/utils/methods/direct_attack/direct_attack_passive_party.py @@ -0,0 +1,57 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Define the passive party for direct label inference attack. + +This module provides the implementation of the passive party in scenarios involving direct label inference attacks. +""" +from party.passive_party import VFLPassiveModel + +class DirectAttackPassiveModel(VFLPassiveModel): + """ + Malicious passive party for direct label inference attack. 
class DirectAttackPassiveModel(VFLPassiveModel):
    """
    Malicious passive party for direct label inference attack.

    For every sample of a batch the party takes the index of the first negative entry in the
    gradient it receives as the inferred label, and keeps running counts of how often that
    guess matches ``batch_label``.
    """
    def __init__(self, bottom_model, id=None, args=None):
        super().__init__(bottom_model, id, args)
        self.inferred_wrong = 0
        self.inferred_correct = 0
        # Labels of the current batch; assigned from outside before gradients arrive.
        self.batch_label = None

    def send_components(self):
        """
        Send latent representation to active party.

        Returns:
            Tensor: Output of the local forward computation on the current batch.
        """
        return self._forward_computation(self.X)

    def receive_gradients(self, gradients):
        """
        Score the label guesses for this batch, then update the local bottom model.

        Args:
            gradients (Tensor): Gradients from the active party, indexed as
                ``gradients[sample][logit]``.
        """
        for row in range(len(gradients)):
            sample_grad = gradients[row]
            guess = None
            for col in range(len(sample_grad)):
                if sample_grad[col] < 0:
                    guess = col
                    break
            # A sample without any negative entry is not counted at all.
            if guess is None:
                continue
            if guess == self.batch_label[row]:
                self.inferred_correct += 1
            else:
                self.inferred_wrong += 1

        self.common_grad = gradients
        self._fit(self.X, self.y)
class DirectVFL(VFL):
    """
    VFL for direct label inference attack.

    Training follows the usual per-batch VFL round but additionally hands the batch labels
    to the adversary party so it can score its label guesses, and reports the resulting
    inference accuracy once per epoch.
    """
    def train(self):
        """
        Train the vfl models and record per-epoch metrics.

        Appends to ``record_train_acc``, ``record_test_acc``, ``record_loss`` and
        ``record_attack_metric`` after every epoch.

        Raises:
            ValueError: If ``args['load_model']`` is set; loading is not supported here.
        """
        if self.args['load_model']:
            raise ValueError('do not support load model for direct label inference attack')

        for ep in range(self.args['target_epochs']):
            loss_list = []
            self.set_train()
            self.set_current_epoch(ep)
            attacker = self.party_dict[self.adversary]
            # Counters are per epoch, so the reported accuracy never mixes epochs.
            attacker.inferred_correct = 0
            attacker.inferred_wrong = 0

            for X, Y_batch, _, indices in self.train_loader:
                party_X_train_batch_dict = {}
                if self.args['n_passive_party'] < 2:
                    # Move the axis at position 1 to the front so it unpacks into the
                    # active-party part and the single passive-party part.
                    X = ops.transpose(X, (1, 0, 2, 3, 4))
                    active_X_batch, Xb_batch = X
                    party_X_train_batch_dict[0] = Xb_batch
                else:
                    active_X_batch = X[:, 0:1].squeeze(1)
                    for i in range(self.args['n_passive_party']):
                        party_X_train_batch_dict[i] = X[:, i+1:i+2].squeeze(1)

                attacker.batch_label = Y_batch

                loss, _ = self.fit(active_X_batch, Y_batch, party_X_train_batch_dict, indices)
                loss_list.append(loss)
            self.scheduler_step()

            # Compute main-task accuracy.
            ave_loss = np.sum(loss_list)/len(self.train_loader.children[0])
            self.set_state('train')
            self.train_acc = self.predict(self.train_loader, num_classes=self.args['num_classes'],
                                          top_k=self.top_k, n_passive_party=self.args['n_passive_party'])
            self.set_state('test')
            self.test_acc = self.predict(self.test_loader, num_classes=self.args['num_classes'],
                                         top_k=self.top_k, n_passive_party=self.args['n_passive_party'])

            # Guard the ratio: if no sample was scored this epoch the denominator is zero.
            scored = attacker.inferred_correct + attacker.inferred_wrong
            self.inference_acc = attacker.inferred_correct / scored if scored else 0.0

            print(f"--- epoch: {ep}, train loss: {ave_loss}, train_acc: {self.train_acc * 100}%, "
                  f"test acc: {self.test_acc * 100}%, direct label inference accuracy: {self.inference_acc}")
            self.record_train_acc.append(self.train_acc)
            self.record_test_acc.append(self.test_acc)
            self.record_loss.append(ave_loss)
            self.record_attack_metric.append(self.inference_acc)
class GRPassiveModel(VFLPassiveModel):
    """
    Malicious passive party for gradient-replacement backdoor.

    While sending components it swaps the representation of each target-class sample for
    the representation of a paired backdoor sample; while receiving gradients it copies the
    gradient seen for a target sample onto its paired backdoor sample.
    """
    def __init__(self, bottom_model, amplify_ratio=1, id=None, args=None):
        super().__init__(bottom_model, id=id, args=args)
        self.amplify_ratio = amplify_ratio
        self.is_debug = False
        self.components = None
        self.target_grad = None
        self.target_indices = None
        self.backdoor_indices = None
        self.backdoor_X = {}
        # target sample index -> paired backdoor sample index
        self.pair_set = {}
        # target sample index -> last gradient received for that sample
        self.target_gradients = {}

    def set_epoch(self, epoch):
        """
        Set the current epoch for the passive party.

        Args:
            epoch (int): The current epoch for the passive party.
        """
        self.epoch = epoch

    def set_backdoor_indices(self, target_indices, backdoor_indices, backdoor_X):
        """
        Set the target indices, backdoor indices and backdoor samples.

        Args:
            target_indices (List[int]): Indices of samples labeled as the backdoor class in the
                normal train dataset.
            backdoor_indices (List[int]): Indices of backdoor samples in the normal train dataset,
                used for gradient replacement.
            backdoor_X (dict): The poisoned samples, looked up by sample index.
        """
        self.backdoor_X = backdoor_X
        self.backdoor_indices = backdoor_indices
        self.target_indices = target_indices

    def receive_gradients(self, gradients):
        """
        Receive gradients from the active party and update parameters of the local bottom model.

        Args:
            gradients (List[Tensor]): Gradients from the active party, one entry per batch sample.
        """
        gradients = gradients.copy()

        # Remember the gradient of every target-class sample in this batch.
        for position, raw_id in enumerate(self.indices):
            sample_id = raw_id.item()
            if sample_id in self.target_indices:
                self.target_gradients[sample_id] = gradients[position]

        # Overwrite the gradient of each backdoor sample with its partner's target gradient.
        for position, raw_id in enumerate(self.indices):
            sample_id = raw_id.item()
            if sample_id not in self.backdoor_indices:
                continue
            for target_id, partner_id in self.pair_set.items():
                if partner_id != sample_id:
                    continue
                replacement = self.target_gradients[target_id]
                if replacement is not None:
                    gradients[position] = self.amplify_ratio * replacement
                # Only the first pairing found for this backdoor sample is used.
                break

        self.common_grad = gradients
        self._fit(self.X, self.components)

    def send_components(self):
        """
        Send latent representation to the active party.

        Returns:
            Tensor: A copy of the local representation in which rows of target-class samples
            are replaced by the representation of their paired backdoor sample.
        """
        latent = self._forward_computation(self.X)
        self.components = latent
        outgoing = latent.copy()

        for position, raw_id in enumerate(self.indices):
            sample_id = raw_id.item()
            if sample_id not in self.target_indices:
                continue
            if sample_id in self.pair_set:
                partner = self.pair_set[sample_id]
            else:
                # First time this target sample is seen: pair it with a random backdoor sample.
                partner = self.backdoor_indices[random.randint(0, len(self.backdoor_indices)-1)]
                self.pair_set[sample_id] = partner
            poisoned = ops.unsqueeze(self.backdoor_X[partner], 0)
            outgoing[position] = self.bottom_model.forward(poisoned)[0]

        return outgoing
class VFLActiveModel(object):
    """
    VFL active party.

    Holds the labels, the local bottom model and, when model splitting is used, the top
    model. It combines its own latent representation with those received from the passive
    parties, computes the loss and hands each party the gradient for its representation.
    """
    def __init__(self, bottom_model, args, top_model=None):
        super(VFLActiveModel, self).__init__()
        self.bottom_model = bottom_model
        self.is_debug = False

        self.classifier_criterion = nn.CrossEntropyLoss()
        self.parties_grad_component_list = []
        self.X = None
        self.y = None
        self.bottom_y = None
        self.top_grads = None
        self.parties_grad_list = []
        self.epoch = None
        self.indices = None
        # Filled in during _fit; initialised here so reading them early does not raise
        # AttributeError.
        self.K_U = None
        self.loss = None
        self.para_grad_list = None

        self.top_model = top_model
        self.top_trainable = self.top_model is not None

        self.args = args.copy()

        if self.args['cuda']:
            mindspore.set_context(device_target="GPU")

        self.attack_indices = []
        self.grad_op = ops.GradOperation(get_by_list=True, sens_param=True)
        self.bottom_gradient_function = self.grad_op(self.bottom_model, self.bottom_model.model.trainable_params())
        if self.top_trainable:
            self.top_grad_op = ops.value_and_grad(self.forward_fn, grad_position=0,
                                                  weights=self.top_model.model.trainable_params())
        else:
            self.top_grad_op = ops.value_and_grad(self.forward_fn, grad_position=0, weights=None)

    def set_indices(self, indices):
        """
        Set the current indices for the active party.

        Args:
            indices (List[int]): A list of sample indices to be set as the current indices
                for the active party.
        """
        self.indices = indices

    def set_epoch(self, epoch):
        """
        Set the current epoch for the active party.

        Args:
            epoch (int): The current epoch for the active party.
        """
        self.epoch = epoch

    def set_batch(self, X, y):
        """
        Set the data and labels for the active party.

        Args:
            X (Tensor): The input data for the active party.
            y (Tensor): The labels corresponding to the input data.
        """
        self.X = X
        self.y = y

    def _aggregate(self, components):
        """
        Combine latent representations according to ``args['aggregate']``.

        Args:
            components (List[Tensor]): Representations to combine, local one first.

        Returns:
            Tensor: Concatenation along the last axis ('Concate'), element-wise sum ('Add')
            or element-wise mean ('Mean').

        Raises:
            ValueError: If ``args['aggregate']`` is none of the three supported modes.
        """
        mode = self.args['aggregate']
        if mode == 'Concate':
            return ops.cat(components, -1)
        if mode in ('Add', 'Mean'):
            total = components[0]
            for comp in components[1:]:
                total = total + comp
            if mode == 'Mean':
                total = total / len(components)
            return total
        raise ValueError(f"unsupported aggregate mode: {mode!r}; expected 'Concate', 'Add' or 'Mean'")

    def _fit(self, X, y):
        """
        Compute gradients and update the local bottom model and top model.

        Args:
            X (Tensor): The input data of the active party.
            y (Tensor): Labels corresponding to the input data.
        """
        # Get local latent representation
        self.bottom_y = self.bottom_model.forward(X)
        self.K_U = self.bottom_y

        # Compute gradients based on labels, including gradients for passive parties
        self._compute_common_gradient_and_loss(y)

        # Update parameters of local bottom model and top model
        self._update_models(X, y)

    def predict(self, X, component_list, type):
        """
        Get the final prediction.

        Args:
            X (Tensor): Feature of the active party.
            component_list (List[Tensor]): Latent representations from passive parties.
            type: Not used by this method; kept for interface compatibility.

        Returns:
            Tensor: Softmax over axis 1 of the combined output.

        Raises:
            ValueError: If a top model is used and ``args['aggregate']`` is not supported.
        """
        # Get local latent representation
        U = self.bottom_model.forward(X)

        if not self.top_trainable:
            # Sum up latent representation in VFL without model splitting
            for comp in component_list:
                U = U + comp
        else:
            # Use top model to predict in VFL with model splitting
            U = self.top_model.forward(self._aggregate([U] + list(component_list)))
        return ops.softmax(U, axis=1)

    def receive_components(self, component_list):
        """
        Receive latent representations from passive parties.

        Args:
            component_list (List[Tensor]): Latent representations from passive parties.
        """
        self.parties_grad_component_list.extend(component_list)

    def fit(self):
        """
        Backward.

        Runs one update on the stored batch and clears the received components afterwards,
        so the next round starts from an empty list.
        """
        self.parties_grad_list = []
        self._fit(self.X, self.y)
        self.parties_grad_component_list = []

    def forward_fn(self, top_input, y):
        """
        Provide a forward computation function for the `value_and_grad()` function.

        Args:
            top_input (Tensor): Combined latent representation; passed through the top model
                when one is present, otherwise used directly as the logits.
            y (Tensor): The label tensor used for computing the forward pass.

        Returns:
            Tensor: The classification loss.
        """
        U = self.top_model.forward(top_input) if self.top_trainable else top_input
        return self.classifier_criterion(U, y)

    def _compute_common_gradient_and_loss(self, y):
        """
        Compute loss and gradients, including gradients for passive parties.

        Stores the gradient for the local bottom model in ``top_grads``, appends one gradient
        per passive party to ``parties_grad_list`` and stores the batch-scaled loss in ``loss``.

        Args:
            y (Tensor): The label tensor used for computing the loss and gradients.

        Raises:
            ValueError: If a top model is used and ``args['aggregate']`` is not supported.
        """
        n_parties = len(self.parties_grad_component_list) + 1

        if self.top_trainable:
            top_input = self._aggregate([self.K_U] + self.parties_grad_component_list)
            class_loss, (top_input_grad, para_grad_list) = self.top_grad_op(top_input, y)
            self.para_grad_list = para_grad_list
            mode = self.args['aggregate']
            if mode == 'Concate':
                # Each party owns one slice of the concatenated input; slices are
                # output_dim wide, i.e. all bottom models are taken to share that width.
                length = self.bottom_model.output_dim
                grad_list = [top_input_grad[:, i*length:(i+1)*length] for i in range(n_parties)]
            elif mode == 'Mean':
                grad_list = [top_input_grad / n_parties for _ in range(n_parties)]
            else:
                grad_list = [top_input_grad] * n_parties
        else:
            top_input = self.K_U
            for grad_comp in self.parties_grad_component_list:
                top_input = top_input + grad_comp
            class_loss, top_input_grad = self.top_grad_op(top_input, y)
            # With a plain sum every party receives the same gradient.
            grad_list = [top_input_grad] * n_parties

        # Save gradients of local bottom model.
        self.top_grads = grad_list[0]
        # Save gradients for passive parties.
        self.parties_grad_list.extend(grad_list[1:])

        self.loss = class_loss.item()*self.K_U.shape[0]

    def send_gradients(self):
        """
        Send gradients to passive parties.

        Returns:
            List[Tensor]: A list of gradient tensors to be sent to passive parties.
        """
        return self.parties_grad_list

    def _update_models(self, X, y):
        """
        Update parameters of the local bottom model and top model.

        Args:
            X (Tensor): Features of the active party.
            y (Tensor): The labels.
        """
        if self.top_trainable:
            self.top_model.backward_(self.para_grad_list)
        self.bottom_backward(X, self.bottom_y, self.top_grads)

    def bottom_backward(self, x, y, grad_wrt_output):
        """
        Update the bottom model.

        Args:
            x (Tensor): The input data.
            y (Tensor): The model output; not used by this method.
            grad_wrt_output (Tensor): The gradients with respect to the output.
        """
        bottom_param_grad = self.bottom_gradient_function(x, grad_wrt_output)
        self.bottom_model.backward_(bottom_param_grad)

    def get_loss(self):
        """
        Return the loss stored by the most recent gradient computation.
        """
        return self.loss

    def save(self):
        """
        Save model to local file.
        """
        if self.top_trainable:
            self.top_model.save(time=self.args['file_time'])
        self.bottom_model.save(time=self.args['file_time'])

    def load(self):
        """
        Load model from local file.
        """
        if self.top_trainable:
            self.top_model.load(time=self.args['load_time'])
        self.bottom_model.load(time=self.args['load_time'])

    def set_train(self):
        """
        Set train mode.
        """
        if self.top_trainable:
            self.top_model.set_train(True)
        self.bottom_model.set_train(True)

    def set_eval(self):
        """
        Set eval mode.
        """
        if self.top_trainable:
            self.top_model.set_train(False)
        self.bottom_model.set_train(False)

    def scheduler_step(self):
        """
        Adjust learning rate during training.
        """
        if self.top_trainable and self.top_model.scheduler is not None:
            self.top_model.scheduler.step()
        if self.bottom_model.scheduler is not None:
            self.bottom_model.scheduler.step()

    def set_args(self, args):
        """
        Set arguments for the active party.

        Args:
            args (dict): A dictionary of arguments for the active party.
        """
        self.args = args

    def zero_grad(self):
        """
        Clear gradients.
        """
        pass
class VFLPassiveModel(object):
    """
    VFL passive party.

    Owns one bottom model: it turns its features into a latent representation for the
    active party and updates the model from the gradient it gets back.
    """
    def __init__(self, bottom_model, id=None, args=None):
        super().__init__()
        self.bottom_model = bottom_model
        self.id = id
        self.args = args
        self.is_debug = False
        self.X = None
        self.y = None
        self.indices = None
        self.epoch = None
        self.common_grad = None
        self.grad_op = ops.GradOperation(get_by_list=True, sens_param=True)
        trainable = self.bottom_model.model.trainable_params()
        self.bottom_gradient_function = self.grad_op(self.bottom_model, trainable)

    def set_epoch(self, epoch):
        """
        Set the current epoch for the passive party.

        Args:
            epoch (int): The current epoch for the passive party.
        """
        self.epoch = epoch

    def set_batch(self, X, indices):
        """
        Set the current data and indices for the passive party.

        Args:
            X (Tensor): The input data for the passive party.
            indices (List[int]): A list of sample indices to be set as the current indices
                for the passive party.
        """
        self.indices = indices
        self.X = X

    def _forward_computation(self, X, model=None):
        """
        Perform the forward computation and remember its result in ``self.y``.

        Args:
            X (Tensor): Features of the passive party.
            model (Model): Optional model to use instead of the local bottom model.

        Returns:
            Tensor: The latent representation of the passive party.
        """
        runner = self.bottom_model if model is None else model
        latent = runner.forward(X)
        self.y = latent
        return latent

    def _fit(self, X, y):
        """
        Backward.

        Args:
            X (Tensor): Features of the passive party.
            y (Tensor): The latent representation of the passive party.
        """
        self.bottom_backward(X, y, self.common_grad, self.epoch)

    def bottom_backward(self, x, y, grad_wrt_output, epoch):
        """
        Update the bottom model.

        Args:
            x (Tensor): The input data.
            y (Tensor): The model output; not used by this method.
            grad_wrt_output (Tensor): The gradients with respect to the output.
            epoch (int): The current epoch; not used by this method.
        """
        param_grads = self.bottom_gradient_function(x, grad_wrt_output)
        self.bottom_model.backward_(param_grads)

    def receive_gradients(self, gradients):
        """
        Receive gradients from the active party and update parameters of the local bottom model.

        Args:
            gradients (List[Tensor]): Gradients from the active party.
        """
        self.common_grad = gradients
        self._fit(self.X, self.y)

    def send_components(self):
        """
        Send latent representation to the active party.

        Returns:
            Tensor: Forward output for the batch stored by ``set_batch``.
        """
        return self._forward_computation(self.X)

    def predict(self, X, is_attack=False):
        """
        Compute the output.

        Args:
            X (Tensor): Input data.
            is_attack (bool): Not used by this method.

        Returns:
            Tensor: Embeddings to be sent to the active party.
        """
        return self._forward_computation(X)

    def save(self):
        """
        Save model to local file.
        """
        self.bottom_model.save(id=self.id, time=self.args['file_time'])

    def load(self, load_attack=False):
        """
        Load the model from a local file.

        Args:
            load_attack (bool): When true the model saved under the name 'attack' is loaded
                instead of the one saved under this party's id.
        """
        if not load_attack:
            self.bottom_model.load(id=self.id, time=self.args['load_time'])
        else:
            self.bottom_model.load(name='attack', time=self.args['load_time'])

    def set_train(self):
        """
        Set train mode.
        """
        self.bottom_model.set_train(True)

    def set_eval(self):
        """
        Set eval mode.
        """
        self.bottom_model.set_train(False)

    def scheduler_step(self):
        """
        Adjust learning rate during training.
        """
        scheduler = self.bottom_model.scheduler
        if scheduler is not None:
            scheduler.step()

    def zero_grad(self):
        """
        Clear gradients.
        """
        pass
class Init_Vfl(object):
    """
    Initialize the vfl.

    Builds the active party and one passive party per bottom model; the passive party at
    position ``args['adversary'] - 1`` becomes an attacker when an attack is configured.
    """
    def __init__(self, args):
        super(Init_Vfl, self).__init__()
        self.args = args

    def get_vfl(self, bottoms, active, train_dl, backdoor_target_indices=None, backdoor_indices=None):
        """
        Generate the VFL system and set the parties.

        The result is stored in ``self.party_list`` (active party first, then the passive
        parties in order); nothing is returned.

        Args:
            bottoms (list): Bottom models; ``bottoms[0]`` goes to the active party and
                ``bottoms[i + 1]`` to passive party ``i``.
            active: Top model, used only when ``args['active_top_trainable']`` is truthy.
            train_dl: Train loader yielding ``(X, _, _, indices)``; may be None.
            backdoor_target_indices (List[int]): Indices of samples labeled as the backdoor
                class in the normal train dataset.
            backdoor_indices (List[int]): Indices of backdoor samples in the normal train dataset,
                used for gradient replacement.
        """
        self.traindl = train_dl
        self.backdoor_target_indices = backdoor_target_indices
        self.backdoor_indices = backdoor_indices

        active_bottom_model = bottoms[0]
        party_model_list = [bottoms[i + 1] for i in range(self.args['n_passive_party'])]

        active_top_model = active if self.args['active_top_trainable'] else None

        active_party = VFLActiveModel(bottom_model=active_bottom_model,
                                      args=self.args,
                                      top_model=active_top_model)

        self.party_list = [active_party]
        for i, model in enumerate(party_model_list):
            if self.args['backdoor'] == 'g_r' and i == self.args['adversary'] - 1:
                passive_party = GRPassiveModel(bottom_model=model,
                                               amplify_ratio=self.args['amplify_ratio'], id=i, args=self.args)
                backdoor_X = self._collect_backdoor_samples()
                passive_party.set_backdoor_indices(self.backdoor_target_indices, self.backdoor_indices, backdoor_X)
            elif self.args['label_inference_attack'] == 'direct_attack' and i == self.args['adversary'] - 1:
                passive_party = DirectAttackPassiveModel(bottom_model=model, id=i, args=self.args)
            else:
                passive_party = VFLPassiveModel(bottom_model=model, id=i, args=self.args)
            self.party_list.append(passive_party)

    def _collect_backdoor_samples(self):
        """
        Gather the adversary's features for every backdoor sample found in the train loader.

        Returns:
            dict: Sample index -> the adversary's slice of that sample. Empty when there is
            no train loader or no backdoor index.
        """
        backdoor_X = {}
        if self.traindl is None:
            return backdoor_X

        wanted = set(self.backdoor_indices or ())
        for X, _, _, indices in self.traindl:
            # Convert once per batch and keep the first position of each index, which is
            # what list.index() would return.
            first_pos = {}
            for pos, sample_id in enumerate(indices.tolist()):
                first_pos.setdefault(sample_id, pos)
            hits = wanted.intersection(first_pos)
            if not hits:
                continue
            if self.args['n_passive_party'] < 2:
                X = ops.transpose(X, (1, 0, 2, 3, 4))
                _, Xb_batch = X
            else:
                Xb_batch = X[:, self.args['adversary']:self.args['adversary']+1].squeeze(1)
            for sample_id in hits:
                backdoor_X[sample_id] = Xb_batch[first_pos[sample_id]]
        return backdoor_X
+ """ + def __init__(self,train_loader, test_loader, init_vfl, backdoor_test_loader=None): + super(VFL,self).__init__() + self.active_party = init_vfl.party_list[0] + self.party_dict = {} + self.party_ids = [] + self.args = init_vfl.args + for index, party in enumerate(init_vfl.party_list[1:]): + self.add_party(id=index, party_model=party) + self.train_loader = train_loader + self.test_loader = test_loader + self.backdoor_test_loader = backdoor_test_loader + self.top_k = self.args['topk'] + self.state = None + self.is_attack = False + self.adversary = self.args['adversary'] - 1 + self.record_loss = [] + self.record_train_acc = [] + self.record_test_acc = [] + self.record_attack_metric = [] + self.record_results = [] + + def add_party(self, *, id, party_model): + """ + Add a passive party to the VFL system. + + Args: + id (int): The identifier for the passive party. + party_model (VFLPassiveModel): The passive party model to be added. + """ + self.party_dict[id] = party_model + self.party_ids.append(id) + + def set_current_epoch(self, ep): + """ + Set the current train epoch. + + Args: + ep (int): The current train epoch. + """ + self.active_party.set_epoch(ep) + for i in self.party_ids: + self.party_dict[i].set_epoch(ep) + + + def set_state(self, type): + """ + Set the state of the system. + + Args: + type (str): The state type, which can be 'train', 'test', or 'attack'. + """ + self.state = type + if type == 'attack': + self.is_attack = True + else: + self.is_attack = False + + def train(self): + """ + Load or train the vfl model. 
+ """ + if self.args['load_model']: + self.load() + self.set_state('test') + self.train_acc = self.predict(self.train_loader, num_classes=self.args['num_classes'],top_k=self.top_k, + n_passive_party=self.args['n_passive_party']) + self.set_state('test') + self.test_acc = self.predict(self.test_loader, num_classes=self.args['num_classes'],top_k=self.top_k, + n_passive_party=self.args['n_passive_party']) + if self.backdoor_test_loader is not None: + self.set_state('attack') + self.backdoor_acc = self.predict(self.backdoor_test_loader, num_classes=self.args['num_classes'], + top_k=self.top_k, n_passive_party=self.args['n_passive_party']) + self.set_state('test') + print(f'train_acc: {self.train_acc}, test_acc: {self.test_acc}, backdoor_acc:{self.backdoor_acc}') + else: + print(f'train_acc: {self.train_acc}, test_acc: {self.test_acc}') + else: + for ep in range(self.args['target_epochs']): + loss_list = [] + self.set_train() + self.set_current_epoch(ep) + + for batch_idx, (X, Y_batch, old_imgb, indices) in enumerate(self.train_loader): + party_X_train_batch_dict = {} + if self.args['n_passive_party'] < 2: + X = ops.transpose(X, (1, 0, 2, 3, 4)) + active_X_batch, Xb_batch = X + party_X_train_batch_dict[0] = Xb_batch + else: + active_X_batch = X[:, 0:1].squeeze(1) + for i in range(self.args['n_passive_party']): + party_X_train_batch_dict[i] = X[:, i + 1:i + 2].squeeze(1) + loss, grad_list = self.fit(active_X_batch, Y_batch, party_X_train_batch_dict, indices) + loss_list.append(loss) + self.scheduler_step() + + # Compute main-task accuracy. 
+ ave_loss = np.sum(loss_list) / len(self.train_loader.children[0]) + self.set_state('train') + self.train_acc = self.predict(self.train_loader, num_classes=self.args['num_classes'],top_k=self.top_k, + n_passive_party=self.args['n_passive_party']) + self.set_state('test') + self.test_acc = self.predict(self.test_loader, num_classes=self.args['num_classes'],top_k=self.top_k, + n_passive_party=self.args['n_passive_party']) + self.record_train_acc.append(self.train_acc) + self.record_test_acc.append(self.test_acc) + self.record_loss.append(ave_loss) + # Compute backdoor task accuracy. + if self.backdoor_test_loader is not None: + self.set_state('attack') + self.backdoor_acc = self.predict(self.backdoor_test_loader, + num_classes=self.args['num_classes'], top_k=self.top_k, + n_passive_party=self.args['n_passive_party']) + self.set_state('test') + print( + f"--- epoch: {ep}, train loss: {ave_loss}, train_acc: {self.train_acc * 100}%, " + f"test acc: {self.test_acc * 100}%, backdoor acc: {self.backdoor_acc * 100}%") + self.record_attack_metric.append(self.backdoor_acc) + else: + print( + f"--- epoch: {ep}, train loss: {ave_loss}, train_acc: {self.train_acc * 100}%, " + f"test acc: {self.test_acc * 100}%") + + def fit(self, active_X, y, party_X_dict, indices): + """ + Perform VFL training for one batch. + + Args: + active_X (Tensor): Features of the active party. + y (Tensor): Labels for the current batch. + party_X_dict (dict): Features of passive parties, with party IDs as keys. + indices (List[int]): Indices of samples in the current batch. + + Returns: + tuple: A tuple containing the loss computed for the batch and the gradients returned by the bottom models. + The gradients are marked as invalid in normal training. + """ + # Set features and labels for active party. + self.active_party.set_batch(active_X, y) + self.active_party.set_indices(indices) + + # Set features for all passive parties. 
+ for idx, party_X in party_X_dict.items(): + self.party_dict[idx].set_batch(party_X, indices) + + # All passive parties output latent representations and upload them to active party. + comp_list = [] + for id in self.party_ids: + party = self.party_dict[id] + logits = party.send_components() + comp_list.append(logits) + self.active_party.receive_components(component_list=comp_list) + + # Active party compute gradients based on labels and update parameters of its bottom model and top model. + self.active_party.fit() + loss = self.active_party.get_loss() + + # Active party send gradients to passive parties, then passive parties update their bottom models. + parties_grad_list = self.active_party.send_gradients() + grad_list = [] + for index, id in enumerate(self.party_ids): + party = self.party_dict[id] + grad = party.receive_gradients(parties_grad_list[index]) + grad_list.append(grad) + + return loss, grad_list + + def save(self): + """ + Save all models in VFL, including top model and all bottom models. + """ + self.active_party.save() + for id in self.party_ids: + self.party_dict[id].save() + + def load(self, load_attack=False): + """ + Load all models in the VFL system, including the top model and all bottom models. + + Args: + load_attack (bool): invalid. + """ + self.active_party.load() + for id in self.party_ids: + if load_attack and id == 0: + self.party_dict[id].load(load_attack=True) + else: + self.party_dict[id].load() + + def predict(self, test_loader, num_classes, top_k=1, n_passive_party=2): + """ + Compute the accuracy of the VFL system on the test dataset. + + Args: + test_loader (DataLoader): Loader for the test dataset. + num_classes (int): Number of classes in the dataset. + dataset (str): Name of the dataset. + top_k (int): Top-k value for accuracy computation. + n_passive_party (int): Number of passive parties in the VFL system. + is_attack (bool): Whether to compute attack accuracy. + + Returns: + float: The computed accuracy of the VFL system. 
+ """ + y_predict = [] + y_true = [] + + self.set_eval() + + for batch_idx, (X, targets, old_imgb, indices) in enumerate(test_loader): + party_X_test_dict = {} + if self.args['n_passive_party'] < 2: + X = ops.transpose(X, (1, 0, 2, 3, 4)) + active_X_inputs, Xb_inputs = X + party_X_test_dict[0] = Xb_inputs + else: + active_X_inputs = X[:, 0:1].squeeze(1) + for i in range(n_passive_party): + party_X_test_dict[i] = X[:, i + 1:i + 2].squeeze(1) + y_true += targets.tolist() + + self.active_party.indices = indices + + y_prob_preds = self.batch_predict(active_X_inputs, party_X_test_dict) + y_predict += y_prob_preds.tolist() + + acc = self.accuracy(y_true, y_predict, top_k=top_k, num_classes=num_classes) + return acc + + def write(self): + ''' + Save the models. + ''' + if self.args['save_model']: + self.save() + + def batch_predict(self, active_X, party_X_dict): + """ + Predict labels with help of all parties. + + Args: + active_X (Tensor): Features of the active party. + party_X_dict (dict): Features of passive parties, with party IDs as keys. + attack_output (Tensor, optional): Latent representation output by the attacker if provided. + is_attack (bool): Whether the prediction process is for an attack scenario. + + Returns: + Tensor: The prediction labels. + """ + comp_list = [] + # Passive parties send latent representations + for id in self.party_ids: + comp_list.append(self.party_dict[id].predict(party_X_dict[id], self.is_attack)) + + # Active party make the final prediction + return self.active_party.predict(active_X, component_list=comp_list, type=self.state) + + def set_train(self): + """ + Set train mode for all parties. + """ + self.active_party.set_train() + for id in self.party_ids: + self.party_dict[id].set_train() + + def set_eval(self): + """ + Set eval mode for all parties. 
+ """ + self.active_party.set_eval() + for id in self.party_ids: + self.party_dict[id].set_eval() + + def scheduler_step(self): + """ + Adjust learning rate for all parties during training. + """ + self.active_party.scheduler_step() + for id in self.party_ids: + self.party_dict[id].scheduler_step() + + def zero_grad(self): + """ + Clear gradients for all parties. + """ + self.active_party.zero_grad() + for id in self.party_ids: + self.party_dict[id].zero_grad() + + def accuracy(self, y_true, y_pred, num_classes=None, top_k=1): + """ + Compute model accuracy. + + Args: + y_true (list): List of ground-truth labels. + y_pred (list): List of prediction labels. + dataset (str): Name of the dataset. + num_classes (int): Number of classes in the dataset. + top_k (int, optional): Top-k value for accuracy computation. Default is 1. + is_attack (bool, optional): Whether to compute accuracy for attack scenarios. + + Returns: + float: The computed model accuracy. + """ + y_pred = np.array(y_pred) + if np.any(np.isnan(y_pred)) or not np.all(np.isfinite(y_pred)): + raise ValueError('accuracy y_pred is isnan') + temp_y_pred = [] + if top_k == 1: + for pred in y_pred: + temp = np.max(pred) + temp_y_pred.append(np.where(pred == temp)[0][0]) + acc = accuracy_score(y_true, temp_y_pred) + else: + acc = top_k_accuracy_score(y_true, y_pred, k=top_k, labels=np.arange(num_classes)) + return acc