Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add run_config to ClientApp Context #3751

Merged
merged 39 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
dc4c5b7
feat(framework:skip) Add override to context and utility funcion
charlesbvll Jul 8, 2024
110031d
Fix docstring formatting
charlesbvll Jul 8, 2024
563bfeb
Remove non parameter from docstring
charlesbvll Jul 8, 2024
c38e0dd
Split function calls to multiple lines
charlesbvll Jul 8, 2024
ff10f09
Return empty dict if config doesn't exist
charlesbvll Jul 8, 2024
a4478cc
Add test for fusing dicts
charlesbvll Jul 8, 2024
bffa910
Remove ability to add new values
charlesbvll Jul 8, 2024
0010d2e
Remove unused imports
charlesbvll Jul 8, 2024
592c6d4
Merge branch 'main' into add-override-common
charlesbvll Jul 8, 2024
ec48536
feat(framework) Add `run_config` to `ClientApp` `Context`
charlesbvll Jul 8, 2024
0ef7d36
Merge branch 'add-override-common' into add-override-clientapp
charlesbvll Jul 8, 2024
a971117
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 8, 2024
4baef0c
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 8, 2024
c1b3b16
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 9, 2024
dd22377
Merge branch 'main' into add-override-clientapp
danieljanes Jul 9, 2024
de63171
refactor(framework:skip) Add run_config as required parameter to Context
charlesbvll Jul 9, 2024
951f08a
Merge branch 'add-run-config-context' into add-override-clientapp
charlesbvll Jul 9, 2024
a252dc5
Only add config on first start
charlesbvll Jul 9, 2024
f71a1a9
Remove unused import
charlesbvll Jul 9, 2024
f398fe9
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 9, 2024
13cd1ab
Merge branch 'main' into add-override-clientapp
danieljanes Jul 9, 2024
d1e1557
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 10, 2024
248069a
Make run_config unmodifiable
charlesbvll Jul 10, 2024
4b75feb
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 10, 2024
648d414
Fix node_state
charlesbvll Jul 10, 2024
c0a7c33
Add type annotation
charlesbvll Jul 10, 2024
a9bfc70
Merge branch 'main' into add-override-clientapp
danieljanes Jul 11, 2024
a64379e
Fix node_state
charlesbvll Jul 11, 2024
5209b16
Fix node state
charlesbvll Jul 11, 2024
df59133
Add RunInfo dataclass
charlesbvll Jul 11, 2024
1962c4a
Fix test
charlesbvll Jul 11, 2024
7996a61
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 11, 2024
1102bd3
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 11, 2024
7f7aad1
Improve naming
charlesbvll Jul 11, 2024
f2a8123
Merge branch 'add-override-clientapp' of https://github.com/adap/flow…
charlesbvll Jul 11, 2024
f25762d
Merge branch 'main' into add-override-clientapp
charlesbvll Jul 11, 2024
6ea228b
Fix CI
charlesbvll Jul 11, 2024
185fd4a
Rename run_info
charlesbvll Jul 11, 2024
06ecf9a
Rename run_info to run_infos
charlesbvll Jul 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
from dataclasses import dataclass
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

from cryptography.hazmat.primitives.asymmetric import ec
Expand Down Expand Up @@ -193,6 +194,7 @@ def _start_client_internal(
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.

Expand Down Expand Up @@ -239,6 +241,8 @@ class `flwr.client.Client` (default: None)
partition_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -316,7 +320,7 @@ def _on_backoff(retry_state: RetryState) -> None:
)

node_state = NodeState(partition_id=partition_id)
run_info: Dict[int, Run] = {}
runs: Dict[int, Run] = {}

while not app_state_tracker.interrupt:
sleep_duration: int = 0
Expand Down Expand Up @@ -366,15 +370,17 @@ def _on_backoff(retry_state: RetryState) -> None:

# Get run info
run_id = message.metadata.run_id
if run_id not in run_info:
if run_id not in runs:
if get_run is not None:
run_info[run_id] = get_run(run_id)
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = Run(run_id, "", "", {})
runs[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)
node_state.register_context(
run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir
)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=run_id)
Expand All @@ -388,7 +394,7 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
run: Run = run_info[run_id]
run: Run = runs[run_id]
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version
)
Expand Down
44 changes: 36 additions & 8 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,53 @@
"""Node state."""


from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config
from flwr.common.typing import Run


@dataclass()
class RunInfo:
"""Contains the Context and initial run_config of a Run."""

context: Context
initial_run_config: Dict[str, str]


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self, partition_id: Optional[int]) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_contexts: Dict[int, Context] = {}
self.run_infos: Dict[int, RunInfo] = {}
self._partition_id = partition_id

def register_context(self, run_id: int) -> None:
def register_context(
self,
run_id: int,
run: Optional[Run] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Register new run context for this node."""
if run_id not in self.run_contexts:
self.run_contexts[run_id] = Context(
state=RecordSet(), run_config={}, partition_id=self._partition_id
if run_id not in self.run_infos:
initial_run_config = get_fused_config(run, flwr_dir) if run else {}
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
state=RecordSet(),
run_config=initial_run_config.copy(),
partition_id=self._partition_id,
),
)

def retrieve_context(self, run_id: int) -> Context:
"""Get run context given a run_id."""
if run_id in self.run_contexts:
return self.run_contexts[run_id]
if run_id in self.run_infos:
return self.run_infos[run_id].context

raise RuntimeError(
f"Context for run_id={run_id} doesn't exist."
Expand All @@ -48,4 +71,9 @@ def retrieve_context(self, run_id: int) -> Context:

def update_context(self, run_id: int, context: Context) -> None:
"""Update run context."""
self.run_contexts[run_id] = context
if context.run_config != self.run_infos[run_id].initial_run_config:
raise ValueError(
"The `run_config` field of the `Context` object cannot be "
f"modified (run_id: {run_id})."
)
self.run_infos[run_id].context = context
5 changes: 3 additions & 2 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
node_state.update_context(run_id=run_id, context=updated_state)

# Verify values
for run_id, context in node_state.run_contexts.items():
for run_id, run_info in node_state.run_infos.items():
assert (
context.state.configs_records["counter"]["count"] == expected_values[run_id]
run_info.context.state.configs_records["counter"]["count"]
== expected_values[run_id]
)
1 change: 1 addition & 0 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def run_supernode() -> None:
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
partition_id=args.partition_id,
flwr_dir=get_flwr_dir(args.flwr_dir),
)

# Graceful shutdown
Expand Down