Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Enable setting run_id when starting simulation #3576

Merged
merged 7 commits into from
Jun 14, 2024
22 changes: 22 additions & 0 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def run_simulation_from_cli() -> None:
backend_name=args.backend,
backend_config=backend_config_dict,
app_dir=args.app_dir,
run_id=args.run_id,
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
verbose_logging=args.verbose,
)
Expand Down Expand Up @@ -168,13 +169,21 @@ def server_th_with_start_checks( # type: ignore
return serverapp_th


def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None:
"""Create a run with a given `run_id`."""
log(DEBUG, "Pre-registering run with id %s", run_id)
state.state().run_ids[run_id] = ("", "") # type: ignore
driver.run_id = run_id


# pylint: disable=too-many-locals
def _main_loop(
num_supernodes: int,
backend_name: str,
backend_config_stream: str,
app_dir: str,
enable_tf_gpu_growth: bool,
run_id: Optional[int] = None,
client_app: Optional[ClientApp] = None,
client_app_attr: Optional[str] = None,
server_app: Optional[ServerApp] = None,
Expand All @@ -195,6 +204,9 @@ def _main_loop(
# Initialize Driver
driver = InMemoryDriver(state_factory)

if run_id:
_init_run_id(driver, state_factory, run_id)

# Get and run ServerApp thread
serverapp_th = run_serverapp_th(
server_app_attr=server_app_attr,
Expand Down Expand Up @@ -244,6 +256,7 @@ def _run_simulation(
client_app_attr: Optional[str] = None,
server_app_attr: Optional[str] = None,
app_dir: str = "",
run_id: Optional[int] = None,
enable_tf_gpu_growth: bool = False,
verbose_logging: bool = False,
) -> None:
Expand Down Expand Up @@ -283,6 +296,9 @@ def _run_simulation(
Add specified directory to the PYTHONPATH and load `ClientApp` from there.
(Default: current working directory.)

run_id : Optional[int]
An integer specifying the ID of the run started when running this function.

enable_tf_gpu_growth : bool (default: False)
A boolean to indicate whether to enable GPU growth on the main thread. This is
desirable if you make use of a TensorFlow model on your `ServerApp` while
Expand Down Expand Up @@ -322,6 +338,7 @@ def _run_simulation(
backend_config_stream,
app_dir,
enable_tf_gpu_growth,
run_id,
client_app,
client_app_attr,
server_app,
Expand Down Expand Up @@ -413,5 +430,10 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
"ClientApp and ServerApp from there."
" Default: current working directory.",
)
parser.add_argument(
"--run-id",
type=int,
help="Sets the id of the run started by the Simulation Engine.",
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
)

return parser
Loading