Skip to content

Commit

Permalink
Merge branch 'main' into amend-linkstate-keys
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 8, 2025
2 parents 0568a3b + 700a666 commit a8ea96e
Show file tree
Hide file tree
Showing 14 changed files with 167 additions and 135 deletions.
4 changes: 2 additions & 2 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
README.md @jafermarq @tanertopal @danieljanes

# Flower Baselines
/baselines @jafermarq @tanertopal @danieljanes
/baselines @chongshenng @jafermarq @tanertopal @danieljanes

# Flower Benchmarks
/benchmarks @jafermarq @tanertopal @danieljanes
Expand All @@ -16,7 +16,7 @@ README.md @jafermarq @tanertopal @danieljanes
/datasets @jafermarq @tanertopal @danieljanes

# Flower Examples
/examples @jafermarq @tanertopal @danieljanes
/examples @chongshenng @jafermarq @tanertopal @danieljanes

# Flower Templates
/src/py/flwr/cli/new/templates @jafermarq @tanertopal @danieljanes
Expand Down
4 changes: 2 additions & 2 deletions baselines/fjord/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ coloredlogs==15.0.1
hydra-core==1.3.2
flwr[simulation]==1.5.0
omegaconf==2.3.0
torch==2.0.1
torchvision==0.15.2
torch==2.2.0
torchvision==0.17.0
tqdm==4.65.0
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ requests = "^2.31.0"
# Optional dependencies (Simulation Engine)
ray = { version = "==2.10.0", optional = true, python = ">=3.9,<3.12" }
# Optional dependencies (REST transport layer)
starlette = { version = "^0.31.0", optional = true }
uvicorn = { version = "^0.23.0", extras = ["standard"], optional = true }
starlette = { version = "^0.45.2", optional = true }
uvicorn = { version = "^0.34.0", extras = ["standard"], optional = true }

[tool.poetry.extras]
simulation = ["ray"]
Expand Down
15 changes: 12 additions & 3 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,19 @@ message ListRunsResponse {
}

message GetLoginDetailsRequest {}
message GetLoginDetailsResponse { map<string, string> login_details = 1; }
message GetLoginDetailsResponse {
string auth_type = 1;
string device_code = 2;
string verification_uri_complete = 3;
int64 expires_in = 4;
int64 interval = 5;
}

message GetAuthTokensRequest { map<string, string> auth_details = 1; }
message GetAuthTokensResponse { map<string, string> auth_tokens = 1; }
message GetAuthTokensRequest { string device_code = 1; }
message GetAuthTokensResponse {
string access_token = 1;
string refresh_token = 2;
}

message StopRunRequest { uint64 run_id = 1; }
message StopRunResponse { bool success = 1; }
8 changes: 6 additions & 2 deletions src/py/flwr/cli/cli_user_auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ def _authenticated_call(

response = continuation(details, request)
if response.initial_metadata():
retrieved_metadata = dict(response.initial_metadata())
self.auth_plugin.store_tokens(retrieved_metadata)
credentials = self.auth_plugin.read_tokens_from_metadata(
response.initial_metadata()
)
# The metadata contains tokens only if they have been refreshed
if credentials is not None:
self.auth_plugin.store_tokens(credentials)

return response

Expand Down
15 changes: 11 additions & 4 deletions src/py/flwr/cli/login/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.common.constant import AUTH_TYPE
from flwr.common.typing import UserAuthLoginDetails
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
GetLoginDetailsRequest,
GetLoginDetailsResponse,
Expand Down Expand Up @@ -64,7 +64,7 @@ def login( # pylint: disable=R0914
login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)

# Get the auth plugin
auth_type = login_response.login_details.get(AUTH_TYPE)
auth_type = login_response.auth_type
auth_plugin = try_obtain_cli_auth_plugin(app, federation, auth_type)
if auth_plugin is None:
typer.secho(
Expand All @@ -75,7 +75,14 @@ def login( # pylint: disable=R0914
raise typer.Exit(code=1)

# Login
auth_config = auth_plugin.login(dict(login_response.login_details), stub)
details = UserAuthLoginDetails(
auth_type=login_response.auth_type,
device_code=login_response.device_code,
verification_uri_complete=login_response.verification_uri_complete,
expires_in=login_response.expires_in,
interval=login_response.interval,
)
credentials = auth_plugin.login(details, stub)

# Store the tokens
auth_plugin.store_tokens(auth_config)
auth_plugin.store_tokens(credentials)
8 changes: 4 additions & 4 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,19 @@ def try_obtain_cli_auth_plugin(
config_path = get_user_auth_config_path(root_dir, federation)

# Load the config file if it exists
config: dict[str, Any] = {}
json_file: dict[str, Any] = {}
if config_path.exists():
with config_path.open("r", encoding="utf-8") as file:
config = json.load(file)
json_file = json.load(file)
# This is the case when the user auth is not enabled
elif auth_type is None:
return None

# Get the auth type from the config if not provided
if auth_type is None:
if AUTH_TYPE not in config:
if AUTH_TYPE not in json_file:
return None
auth_type = config[AUTH_TYPE]
auth_type = json_file[AUTH_TYPE]

# Retrieve auth plugin class and instantiate it
try:
Expand Down
56 changes: 33 additions & 23 deletions src/py/flwr/common/auth_plugin/auth_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,31 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union

from flwr.proto.exec_pb2_grpc import ExecStub

from ..typing import UserAuthCredentials, UserAuthLoginDetails


class ExecAuthPlugin(ABC):
"""Abstract Flower Auth Plugin class for ExecServicer.
Parameters
----------
config : dict[str, Any]
The authentication configuration loaded from a YAML file.
user_auth_config_path : Path
Path to the YAML file containing the authentication configuration.
"""

@abstractmethod
def __init__(self, config: dict[str, Any]):
def __init__(
self,
user_auth_config_path: Path,
):
"""Abstract constructor."""

@abstractmethod
def get_login_details(self) -> dict[str, str]:
def get_login_details(self) -> Optional[UserAuthLoginDetails]:
"""Get the login details."""

@abstractmethod
Expand All @@ -47,7 +52,7 @@ def validate_tokens_in_metadata(
"""Validate authentication tokens in the provided metadata."""

@abstractmethod
def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]:
def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]:
"""Get authentication tokens."""

@abstractmethod
Expand All @@ -62,50 +67,55 @@ class CliAuthPlugin(ABC):
Parameters
----------
user_auth_config_path : Path
The path to the user's authentication configuration file.
credentials_path : Path
Path to the user's authentication credentials file.
"""

@staticmethod
@abstractmethod
def login(
login_details: dict[str, str],
login_details: UserAuthLoginDetails,
exec_stub: ExecStub,
) -> dict[str, Any]:
"""Authenticate the user with the SuperLink.
) -> UserAuthCredentials:
"""Authenticate the user and retrieve authentication credentials.
Parameters
----------
login_details : dict[str, str]
A dictionary containing the user's login details.
login_details : UserAuthLoginDetails
An object containing the user's login details.
exec_stub : ExecStub
An instance of `ExecStub` used for communication with the SuperLink.
A stub for executing RPC calls to the server.
Returns
-------
user_auth_config : dict[str, Any]
A dictionary containing the user's authentication configuration
in JSON format.
UserAuthCredentials
The authentication credentials obtained after login.
"""

@abstractmethod
def __init__(self, user_auth_config_path: Path):
def __init__(self, credentials_path: Path):
"""Abstract constructor."""

@abstractmethod
def store_tokens(self, user_auth_config: dict[str, Any]) -> None:
"""Store authentication tokens from the provided user_auth_config.
def store_tokens(self, credentials: UserAuthCredentials) -> None:
"""Store authentication tokens to the `credentials_path`.
The configuration, including tokens, will be saved as a JSON file
at `user_auth_config_path`.
The credentials, including tokens, will be saved as a JSON file
at `credentials_path`.
"""

@abstractmethod
def load_tokens(self) -> None:
"""Load authentication tokens from the user_auth_config_path."""
"""Load authentication tokens from the `credentials_path`."""

@abstractmethod
def write_tokens_to_metadata(
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
) -> Sequence[tuple[str, Union[str, bytes]]]:
"""Write authentication tokens to the provided metadata."""

@abstractmethod
def read_tokens_from_metadata(
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
) -> Optional[UserAuthCredentials]:
"""Read authentication tokens from the provided metadata."""
2 changes: 2 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
# Constants for user authentication
CREDENTIALS_DIR = ".credentials"
AUTH_TYPE = "auth_type"
ACCESS_TOKEN_KEY = "access_token"
REFRESH_TOKEN_KEY = "refresh_token"


class MessageType:
Expand Down
20 changes: 20 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,23 @@ class InvalidRunStatusException(BaseException):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


# OIDC user authentication types
@dataclass
class UserAuthLoginDetails:
"""User authentication login details."""

auth_type: str
device_code: str
verification_uri_complete: str
expires_in: int
interval: int


@dataclass
class UserAuthCredentials:
"""User authentication tokens."""

access_token: str
refresh_token: str
36 changes: 12 additions & 24 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a8ea96e

Please sign in to comment.