Add demo for vertical federated learning (#9103)
Showing 11 changed files with 360 additions and 102 deletions.
# Experimental Support of Horizontal Federated XGBoost using NVFlare

This directory contains a demo of Horizontal Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

## Training with CPU only

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```

Prepare the data:
```shell
./prepare_data.sh
```
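In a horizontal split, each site holds the same feature columns but a disjoint subset of the rows. The actual partitioning is done by `prepare_data.sh`; the following is only a toy sketch of the idea:

```python
import numpy as np

def horizontal_split(data: np.ndarray, n_sites: int):
    """Split the rows of `data` into `n_sites` disjoint shards."""
    return np.array_split(data, n_sites, axis=0)

rng = np.random.default_rng(0)
table = rng.random((10, 5))           # 10 rows, 5 features
shards = horizontal_split(table, 2)   # one shard per site
# Every site sees all columns, but only its own rows.
assert sum(s.shape[0] for s in shards) == table.shape[0]
```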
Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```

And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```

Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following command:
```shell
submit_job horizontal-xgboost
```
Once the training finishes, the model files should be written to
`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.
Finally, shut everything down from the admin CLI, using `admin` as the password:
```shell
shutdown client
shutdown server
```

## Training with GPUs

To run the demo with GPUs, make sure your machine has at least 2 GPUs.
Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
turned off (see the [README](../../plugin/federated/README.md)).

Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
above.
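For illustration, the change amounts to flipping a single boolean flag in the client configuration. The surrounding key layout sketched below is hypothetical (NVFlare client configs nest executor arguments, but the exact structure is defined by the demo's own file); only the `use_gpus` entry is taken from the text above:

```json
{
  "executors": [
    {
      "executor": {
        "args": {
          "use_gpus": true
        }
      }
    }
  ]
}
```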
# Experimental Support of Vertical Federated XGBoost using NVFlare

This directory contains a demo of Vertical Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

## Training with CPU only

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
Prepare the data (note that this step downloads the HIGGS dataset, which is 2.6 GB compressed
and 7.5 GB uncompressed, so make sure you have enough disk space and a fast internet
connection):
```shell
./prepare_data.sh
```
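In a vertical split, each site holds the same rows but a disjoint subset of the feature columns (typically only one site also holds the label). The actual partitioning of HIGGS is done by `prepare_data.sh`; the following is only a toy sketch of the idea:

```python
import numpy as np

def vertical_split(data: np.ndarray, n_sites: int):
    """Split the columns of `data` into `n_sites` disjoint feature blocks."""
    cols = np.array_split(np.arange(data.shape[1]), n_sites)
    return [data[:, c] for c in cols]

rng = np.random.default_rng(0)
table = rng.random((6, 8))          # 6 rows, 8 features
parts = vertical_split(table, 2)    # feature blocks for site-1 and site-2
# Every site sees the same rows, but only its own columns.
assert all(p.shape[0] == table.shape[0] for p in parts)
```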
Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```

And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```

Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following command:
```shell
submit_job vertical-xgboost
```
Once the training finishes, the model files should be written to
`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.
Finally, shut everything down from the admin CLI, using `admin` as the password:
```shell
shutdown client
shutdown server
```

## Training with GPUs

GPUs are not yet supported by vertical federated XGBoost.
""" | ||
Example of training controller with NVFlare | ||
=========================================== | ||
""" | ||
import multiprocessing | ||
|
||
from nvflare.apis.client import Client | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.impl.controller import Controller, Task | ||
from nvflare.apis.shareable import Shareable | ||
from nvflare.apis.signal import Signal | ||
from trainer import SupportedTasks | ||
|
||
import xgboost.federated | ||
|
||
|
||
class XGBoostController(Controller): | ||
def __init__(self, port: int, world_size: int, server_key_path: str, | ||
server_cert_path: str, client_cert_path: str): | ||
"""Controller for federated XGBoost. | ||
Args: | ||
port: the port for the gRPC server to listen on. | ||
world_size: the number of sites. | ||
server_key_path: the path to the server key file. | ||
server_cert_path: the path to the server certificate file. | ||
client_cert_path: the path to the client certificate file. | ||
""" | ||
super().__init__() | ||
self._port = port | ||
self._world_size = world_size | ||
self._server_key_path = server_key_path | ||
self._server_cert_path = server_cert_path | ||
self._client_cert_path = client_cert_path | ||
self._server = None | ||
|
||
def start_controller(self, fl_ctx: FLContext): | ||
self._server = multiprocessing.Process( | ||
target=xgboost.federated.run_federated_server, | ||
args=(self._port, self._world_size, self._server_key_path, | ||
self._server_cert_path, self._client_cert_path)) | ||
self._server.start() | ||
|
||
def stop_controller(self, fl_ctx: FLContext): | ||
if self._server: | ||
self._server.terminate() | ||
|
||
def process_result_of_unknown_task(self, client: Client, task_name: str, | ||
client_task_id: str, result: Shareable, | ||
fl_ctx: FLContext): | ||
self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.") | ||
|
||
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): | ||
self.log_info(fl_ctx, "XGBoost training control flow started.") | ||
if abort_signal.triggered: | ||
return | ||
task = Task(name=SupportedTasks.TRAIN, data=Shareable()) | ||
self.broadcast_and_wait( | ||
task=task, | ||
min_responses=self._world_size, | ||
fl_ctx=fl_ctx, | ||
wait_time_after_min_received=1, | ||
abort_signal=abort_signal, | ||
) | ||
if abort_signal.triggered: | ||
return | ||
|
||
self.log_info(fl_ctx, "XGBoost training control flow finished.") |
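For context, each worker the controller broadcasts to is expected to join the federated ring as an XGBoost client (the demo's actual client code lives in `trainer.py`). Below is a minimal sketch of the communicator settings such a client would use, assuming the key names of XGBoost's federated communicator; the `federated_env` helper and the `localhost:9091` address are illustrative, not part of the demo:

```python
def federated_env(server_address: str, world_size: int, rank: int) -> dict:
    """Build the communicator settings a federated XGBoost client needs.

    The keys assume XGBoost's federated communicator plugin; rank 0 would
    be site-1 and rank 1 would be site-2 in this demo.
    """
    return {
        "xgboost_communicator": "federated",
        "federated_server_address": server_address,
        "federated_world_size": world_size,
        "federated_rank": rank,
    }

# A worker would then train inside the communicator context, roughly:
#
# import xgboost as xgb
# with xgb.collective.CommunicatorContext(**federated_env("localhost:9091", 2, rank)):
#     dtrain = xgb.DMatrix(...)  # this site's shard of the data
#     bst = xgb.train({"objective": "binary:logistic"}, dtrain)
```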