Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Enforce strong typing for user auth code for flwr CLI #4703

Merged
merged 12 commits into from
Jan 8, 2025
15 changes: 12 additions & 3 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,19 @@ message ListRunsResponse {
}

message GetLoginDetailsRequest {}
message GetLoginDetailsResponse { map<string, string> login_details = 1; }
message GetLoginDetailsResponse {
string auth_type = 1;
string device_code = 2;
string verification_uri_complete = 3;
int64 expires_in = 4;
int64 interval = 5;
}

message GetAuthTokensRequest { map<string, string> auth_details = 1; }
message GetAuthTokensResponse { map<string, string> auth_tokens = 1; }
message GetAuthTokensRequest { string device_code = 1; }
message GetAuthTokensResponse {
string access_token = 1;
string refresh_token = 2;
}

message StopRunRequest { uint64 run_id = 1; }
message StopRunResponse { bool success = 1; }
8 changes: 6 additions & 2 deletions src/py/flwr/cli/cli_user_auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ def _authenticated_call(

response = continuation(details, request)
if response.initial_metadata():
retrieved_metadata = dict(response.initial_metadata())
self.auth_plugin.store_tokens(retrieved_metadata)
credentials = self.auth_plugin.read_tokens_from_metadata(
response.initial_metadata()
)
# The metadata contains tokens only if they have been refreshed
if credentials is not None:
self.auth_plugin.store_tokens(credentials)

return response

Expand Down
15 changes: 11 additions & 4 deletions src/py/flwr/cli/login/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.common.constant import AUTH_TYPE
from flwr.common.typing import UserAuthLoginDetails
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
GetLoginDetailsRequest,
GetLoginDetailsResponse,
Expand Down Expand Up @@ -64,7 +64,7 @@ def login( # pylint: disable=R0914
login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)

# Get the auth plugin
auth_type = login_response.login_details.get(AUTH_TYPE)
auth_type = login_response.auth_type
auth_plugin = try_obtain_cli_auth_plugin(app, federation, auth_type)
if auth_plugin is None:
typer.secho(
Expand All @@ -75,7 +75,14 @@ def login( # pylint: disable=R0914
raise typer.Exit(code=1)

# Login
auth_config = auth_plugin.login(dict(login_response.login_details), stub)
details = UserAuthLoginDetails(
auth_type=login_response.auth_type,
device_code=login_response.device_code,
verification_uri_complete=login_response.verification_uri_complete,
expires_in=login_response.expires_in,
interval=login_response.interval,
)
credentials = auth_plugin.login(details, stub)

# Store the tokens
auth_plugin.store_tokens(auth_config)
auth_plugin.store_tokens(credentials)
8 changes: 4 additions & 4 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,19 @@ def try_obtain_cli_auth_plugin(
config_path = get_user_auth_config_path(root_dir, federation)

# Load the config file if it exists
config: dict[str, Any] = {}
json_file: dict[str, Any] = {}
if config_path.exists():
with config_path.open("r", encoding="utf-8") as file:
config = json.load(file)
json_file = json.load(file)
# This is the case when the user auth is not enabled
elif auth_type is None:
return None

# Get the auth type from the config if not provided
if auth_type is None:
if AUTH_TYPE not in config:
if AUTH_TYPE not in json_file:
return None
auth_type = config[AUTH_TYPE]
auth_type = json_file[AUTH_TYPE]

# Retrieve auth plugin class and instantiate it
try:
Expand Down
56 changes: 33 additions & 23 deletions src/py/flwr/common/auth_plugin/auth_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,31 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union

from flwr.proto.exec_pb2_grpc import ExecStub

from ..typing import UserAuthCredentials, UserAuthLoginDetails


class ExecAuthPlugin(ABC):
"""Abstract Flower Auth Plugin class for ExecServicer.

Parameters
----------
config : dict[str, Any]
The authentication configuration loaded from a YAML file.
user_auth_config_path : Path
Path to the YAML file containing the authentication configuration.
"""

@abstractmethod
def __init__(self, config: dict[str, Any]):
def __init__(
self,
user_auth_config_path: Path,
):
"""Abstract constructor."""

@abstractmethod
def get_login_details(self) -> dict[str, str]:
def get_login_details(self) -> Optional[UserAuthLoginDetails]:
"""Get the login details."""

@abstractmethod
Expand All @@ -47,7 +52,7 @@ def validate_tokens_in_metadata(
"""Validate authentication tokens in the provided metadata."""

@abstractmethod
def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]:
def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]:
"""Get authentication tokens."""

@abstractmethod
Expand All @@ -62,50 +67,55 @@ class CliAuthPlugin(ABC):

Parameters
----------
user_auth_config_path : Path
The path to the user's authentication configuration file.
credentials_path : Path
Path to the user's authentication credentials file.
"""

@staticmethod
@abstractmethod
def login(
login_details: dict[str, str],
login_details: UserAuthLoginDetails,
exec_stub: ExecStub,
) -> dict[str, Any]:
"""Authenticate the user with the SuperLink.
) -> UserAuthCredentials:
"""Authenticate the user and retrieve authentication credentials.

Parameters
----------
login_details : dict[str, str]
A dictionary containing the user's login details.
login_details : UserAuthLoginDetails
An object containing the user's login details.
exec_stub : ExecStub
An instance of `ExecStub` used for communication with the SuperLink.
A stub for executing RPC calls to the server.

Returns
-------
user_auth_config : dict[str, Any]
A dictionary containing the user's authentication configuration
in JSON format.
UserAuthCredentials
The authentication credentials obtained after login.
"""

@abstractmethod
def __init__(self, user_auth_config_path: Path):
def __init__(self, credentials_path: Path):
"""Abstract constructor."""

@abstractmethod
def store_tokens(self, user_auth_config: dict[str, Any]) -> None:
"""Store authentication tokens from the provided user_auth_config.
def store_tokens(self, credentials: UserAuthCredentials) -> None:
"""Store authentication tokens to the `credentials_path`.

The configuration, including tokens, will be saved as a JSON file
at `user_auth_config_path`.
The credentials, including tokens, will be saved as a JSON file
at `credentials_path`.
"""

@abstractmethod
def load_tokens(self) -> None:
"""Load authentication tokens from the user_auth_config_path."""
"""Load authentication tokens from the `credentials_path`."""

@abstractmethod
def write_tokens_to_metadata(
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
) -> Sequence[tuple[str, Union[str, bytes]]]:
"""Write authentication tokens to the provided metadata."""

@abstractmethod
def read_tokens_from_metadata(
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
) -> Optional[UserAuthCredentials]:
"""Read authentication tokens from the provided metadata."""
2 changes: 2 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
# Constants for user authentication
CREDENTIALS_DIR = ".credentials"
AUTH_TYPE = "auth_type"
ACCESS_TOKEN_KEY = "access_token"
REFRESH_TOKEN_KEY = "refresh_token"


class MessageType:
Expand Down
40 changes: 40 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,43 @@ class InvalidRunStatusException(BaseException):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


# OIDC user authentication types
@dataclass
class UserAuthLoginDetails:
"""User authentication login details."""

auth_type: str
device_code: str
verification_uri_complete: str
expires_in: int
interval: int


@dataclass
class UserAuthCredentials:
"""User authentication tokens."""

access_token: str
refresh_token: str


# OIDC user authentication types
@dataclass
class UserAuthLoginDetails:
"""User authentication login details."""

auth_type: str
device_code: str
verification_uri_complete: str
expires_in: int
interval: int


@dataclass
class UserAuthCredentials:
"""User authentication tokens."""

access_token: str
refresh_token: str
36 changes: 12 additions & 24 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading