Add protect method for feature order in fl-xgb #497
Changes from 20 commits
federatedscope/vertical_fl/trainer/__init__.py (new file):

```python
from federatedscope.vertical_fl.trainer.trainer import VerticalTrainer
from federatedscope.vertical_fl.trainer.feature_order_protected_trainer \
    import FeatureOrderProtectedTrainer
```
federatedscope/vertical_fl/trainer/feature_order_protected_trainer.py (new file):

```python
import numpy as np

from federatedscope.vertical_fl.trainer.trainer import VerticalTrainer


class FeatureOrderProtectedTrainer(VerticalTrainer):
    """Protects the feature order exchanged in vertical FL by exposing
    only bin-level order information instead of the exact ranking.
    """
    def __init__(self, model, data, device, config, monitor):
        super(FeatureOrderProtectedTrainer,
              self).__init__(model, data, device, config, monitor)

        assert config.vertical.protect_method != '', \
            "Please specify the adopted method for protecting feature order"
        args = config.vertical.protect_args[0] if len(
            config.vertical.protect_args) > 0 else {}

        if config.vertical.protect_method == 'use_bins':
            self.bin_num = args.get('bin_num', 100)
            self.share_bin = args.get('share_bin', False)
            self.protect_funcs = self._protect_via_bins
            self.split_value = None
        else:
            raise ValueError(
                f"The method {config.vertical.protect_method} "
                f"is not supported")

    def get_feature_value(self, feature_idx, value_idx):
        assert self.split_value is not None

        return self.split_value[feature_idx][value_idx]

    def _protect_via_bins(self, raw_feature_order, data):
        protected_feature_order = list()
        bin_size = int(np.ceil(self.cfg.dataloader.batch_size / self.bin_num))
        split_position = [[] for _ in range(len(raw_feature_order))
                          ] if self.share_bin else None
        self.split_value = [dict() for _ in range(len(raw_feature_order))]
        for i in range(len(raw_feature_order)):
            _protected_feature_order = list()
            for j in range(self.bin_num):
                idx_start = j * bin_size
                idx_end = min((j + 1) * bin_size, len(raw_feature_order[i]))
                # Shuffle the order within each bin so that only bin-level
                # order information is exposed
                feature_order_frame = raw_feature_order[i][idx_start:idx_end]
                np.random.shuffle(feature_order_frame)
                _protected_feature_order.append(feature_order_frame)
                if self.share_bin:
                    if j != self.bin_num - 1:
                        split_position[i].append(idx_end)
                    min_value = min(data[feature_order_frame][:, i])
                    max_value = max(data[feature_order_frame][:, i])
                    # Each shared-bin boundary stores the midpoint between
                    # the max value of the left bin and the min value of the
                    # right bin, accumulated in two halves
                    if j == 0:
                        self.split_value[i][idx_end] = max_value / 2.0
                    elif j == self.bin_num - 1:
                        self.split_value[i][idx_start] += min_value / 2.0
                    else:
                        self.split_value[i][idx_start] += min_value / 2.0
                        self.split_value[i][idx_end] = max_value / 2.0
                else:
                    # Without shared bins, every position within a bin is
                    # mapped to the mean feature value of that bin
                    mean_value = np.mean(data[feature_order_frame][:, i])
                    for x in range(idx_start, idx_end):
                        self.split_value[i][x] = mean_value
            protected_feature_order.append(
                np.concatenate(_protected_feature_order))

        extra_info = None
        if split_position is not None:
            extra_info = {'split_position': split_position}

        return {
            'feature_order': protected_feature_order,
            'extra_info': extra_info
        }

    def _get_feature_order_info(self, data):
        # TODO: the label owner does not need to protect its own feature
        #  order; skipping it requires modifying the split positions
        #  accordingly, so it is left as a follow-up fix
        num_of_feature = data.shape[1]
        feature_order = [0] * num_of_feature
        for i in range(num_of_feature):
            feature_order[i] = data[:, i].argsort()
        return self.protect_funcs(feature_order, data)
```

Review comment (on `_get_feature_order_info`): For convenience, the label owner's feature order is protected here as well; actually, the label owner does not need to do this.

Reply: Since fixing this needs some more effort, such as modifying the split positions accordingly, we can add a TODO item here and fix it later.
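To make the bin-based protection concrete, here is a minimal standalone sketch (not part of the PR; the feature values, bin count, and variable names are made up for illustration). It shows how shuffling within bins hides the exact ranking while preserving bin membership:

```python
import numpy as np

# One feature with 8 samples; raw_order is the exact sorted order.
feature = np.array([0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4])
raw_order = feature.argsort()            # [1 5 3 7 2 6 4 0]

bin_num, bin_size = 4, 2                 # 4 bins, 2 samples per bin
protected = []
for j in range(bin_num):
    frame = raw_order[j * bin_size:(j + 1) * bin_size].copy()
    np.random.shuffle(frame)             # hide the exact order inside the bin
    protected.append(frame)
protected_order = np.concatenate(protected)

# protected_order still reveals which bin each sample falls into,
# but no longer the exact ranking within a bin.
print(raw_order)
print(protected_order)
```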
The trainer factory that dispatches on `vertical.protect_object`:

```python
from federatedscope.vertical_fl.trainer import VerticalTrainer, \
    FeatureOrderProtectedTrainer


def get_vertical_trainer(config, model, data, device, monitor):
    protect_object = config.vertical.protect_object
    if not protect_object:
        return VerticalTrainer(model=model,
                               data=data,
                               device=device,
                               config=config,
                               monitor=monitor)
    elif protect_object == 'feature_order':
        return FeatureOrderProtectedTrainer(model=model,
                                            data=data,
                                            device=device,
                                            config=config,
                                            monitor=monitor)
    else:
        raise ValueError(f'Unknown protect object: {protect_object}')
```
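A sketch of how the factory dispatches, using a mock config for illustration only (`SimpleNamespace` stands in for FederatedScope's config object; `model`, `data`, and `monitor` are assumed to be built elsewhere):

```python
from types import SimpleNamespace

# Mock stand-in for the real config object, mirroring the YAML below.
cfg = SimpleNamespace(vertical=SimpleNamespace(
    protect_object='feature_order',
    protect_method='use_bins',
    protect_args=[{'bin_num': 100, 'share_bin': True}]))

# get_vertical_trainer(cfg, model, data, device, monitor) would return a
# FeatureOrderProtectedTrainer here, and a plain VerticalTrainer when
# cfg.vertical.protect_object is empty.
```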
Example configuration (vertical XGBoost on the Adult dataset):

```yaml
use_gpu: False
device: 0
backend: torch
federate:
  mode: standalone
  client_num: 2
model:
  type: xgb_tree
  lambda_: 0.1
  gamma: 0
  num_of_trees: 10
  max_tree_depth: 3
data:
  root: data/
  type: adult
  splits: [1.0, 0.0]
dataloader:
  type: raw
  batch_size: 2000
criterion:
  type: CrossEntropyLoss
trainer:
  type: verticaltrainer
train:
  optimizer:
    bin_num: 100
    # learning rate for the xgb model
    eta: 0.5
vertical:
  use: True
  dims: [7, 14]
  algo: 'xgb'
  protect_object: 'feature_order'
  protect_method: 'use_bins'
  protect_args: [{'bin_num': 100, 'share_bin': True}]
eval:
  freq: 3
  best_res_update_round_wise_key: test_loss
```
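Assuming the standard FederatedScope entry point, a config like this would be launched with `python federatedscope/main.py --cfg <path-to-this-yaml>`; the `vertical.protect_*` fields above are exactly what `FeatureOrderProtectedTrainer` reads in its constructor.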
Review comment (on the `max_value` assignment for the first bin in `_protect_via_bins`): The value on the right-hand side should be divided by 2.0.
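The intent of the two half-assignments is that each shared-bin boundary ends up storing the midpoint between adjacent bins. A tiny numeric illustration (values made up):

```python
# Boundary between bin j and bin j + 1:
prev_bin_max = 4.0        # max feature value in bin j
next_bin_min = 6.0        # min feature value in bin j + 1

# bin j writes max_value / 2.0, then bin j + 1 adds min_value / 2.0:
split_value = prev_bin_max / 2.0 + next_bin_min / 2.0
assert split_value == 5.0  # the midpoint between the two bins
```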