diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 851275c8ff0f..0804bd4bb5be 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -64,10 +64,19 @@ message ListRunsResponse { } message GetLoginDetailsRequest {} -message GetLoginDetailsResponse { map login_details = 1; } +message GetLoginDetailsResponse { + string auth_type = 1; + string device_code = 2; + string verification_uri_complete = 3; + int64 expires_in = 4; + int64 interval = 5; +} -message GetAuthTokensRequest { map auth_details = 1; } -message GetAuthTokensResponse { map auth_tokens = 1; } +message GetAuthTokensRequest { string device_code = 1; } +message GetAuthTokensResponse { + string access_token = 1; + string refresh_token = 2; +} message StopRunRequest { uint64 run_id = 1; } message StopRunResponse { bool success = 1; } diff --git a/src/py/flwr/cli/cli_user_auth_interceptor.py b/src/py/flwr/cli/cli_user_auth_interceptor.py index 7aa529bc3dd0..253255e71150 100644 --- a/src/py/flwr/cli/cli_user_auth_interceptor.py +++ b/src/py/flwr/cli/cli_user_auth_interceptor.py @@ -54,8 +54,12 @@ def _authenticated_call( response = continuation(details, request) if response.initial_metadata(): - retrieved_metadata = dict(response.initial_metadata()) - self.auth_plugin.store_tokens(retrieved_metadata) + credentials = self.auth_plugin.read_tokens_from_metadata( + response.initial_metadata() + ) + # The metadata contains tokens only if they have been refreshed + if credentials is not None: + self.auth_plugin.store_tokens(credentials) return response diff --git a/src/py/flwr/cli/login/login.py b/src/py/flwr/cli/login/login.py index f08835b360af..b5885fb8ec39 100644 --- a/src/py/flwr/cli/login/login.py +++ b/src/py/flwr/cli/login/login.py @@ -26,7 +26,7 @@ process_loaded_project_config, validate_federation_in_project_config, ) -from flwr.common.constant import AUTH_TYPE +from flwr.common.typing import UserAuthLoginDetails from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 GetLoginDetailsRequest, GetLoginDetailsResponse, @@ -64,7 +64,7 @@ def login( # pylint: disable=R0914 login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request) # Get the auth plugin - auth_type = login_response.login_details.get(AUTH_TYPE) + auth_type = login_response.auth_type auth_plugin = try_obtain_cli_auth_plugin(app, federation, auth_type) if auth_plugin is None: typer.secho( @@ -75,7 +75,14 @@ def login( # pylint: disable=R0914 raise typer.Exit(code=1) # Login - auth_config = auth_plugin.login(dict(login_response.login_details), stub) + details = UserAuthLoginDetails( + auth_type=login_response.auth_type, + device_code=login_response.device_code, + verification_uri_complete=login_response.verification_uri_complete, + expires_in=login_response.expires_in, + interval=login_response.interval, + ) + credentials = auth_plugin.login(details, stub) # Store the tokens - auth_plugin.store_tokens(auth_config) + auth_plugin.store_tokens(credentials) diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index e01a0439c9da..6c571f4e5cdb 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -223,19 +223,19 @@ def try_obtain_cli_auth_plugin( config_path = get_user_auth_config_path(root_dir, federation) # Load the config file if it exists - config: dict[str, Any] = {} + json_file: dict[str, Any] = {} if config_path.exists(): with config_path.open("r", encoding="utf-8") as file: - config = json.load(file) + json_file = json.load(file) # This is the case when the user auth is not enabled elif auth_type is None: return None # Get the auth type from the config if not provided if auth_type is None: - if AUTH_TYPE not in config: + if AUTH_TYPE not in json_file: return None - auth_type = config[AUTH_TYPE] + auth_type = json_file[AUTH_TYPE] # Retrieve auth plugin class and instantiate it try: diff --git a/src/py/flwr/common/auth_plugin/auth_plugin.py b/src/py/flwr/common/auth_plugin/auth_plugin.py index c9dc7a921623..17927b5cf23a 100644 --- a/src/py/flwr/common/auth_plugin/auth_plugin.py +++ b/src/py/flwr/common/auth_plugin/auth_plugin.py @@ -18,26 +18,31 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from pathlib import Path -from typing import Any, Optional, Union +from typing import Optional, Union from flwr.proto.exec_pb2_grpc import ExecStub +from ..typing import UserAuthCredentials, UserAuthLoginDetails + class ExecAuthPlugin(ABC): """Abstract Flower Auth Plugin class for ExecServicer. Parameters ---------- - config : dict[str, Any] - The authentication configuration loaded from a YAML file. + user_auth_config_path : Path + Path to the YAML file containing the authentication configuration. """ @abstractmethod - def __init__(self, config: dict[str, Any]): + def __init__( + self, + user_auth_config_path: Path, + ): """Abstract constructor.""" @abstractmethod - def get_login_details(self) -> dict[str, str]: + def get_login_details(self) -> Optional[UserAuthLoginDetails]: """Get the login details.""" @abstractmethod @@ -47,7 +52,7 @@ def validate_tokens_in_metadata( """Validate authentication tokens in the provided metadata.""" @abstractmethod - def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]: + def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]: """Get authentication tokens.""" @abstractmethod @@ -62,50 +67,55 @@ class CliAuthPlugin(ABC): Parameters ---------- - user_auth_config_path : Path - The path to the user's authentication configuration file. + credentials_path : Path + Path to the user's authentication credentials file. """ @staticmethod @abstractmethod def login( - login_details: dict[str, str], + login_details: UserAuthLoginDetails, exec_stub: ExecStub, - ) -> dict[str, Any]: - """Authenticate the user with the SuperLink. + ) -> UserAuthCredentials: + """Authenticate the user and retrieve authentication credentials. Parameters ---------- - login_details : dict[str, str] - A dictionary containing the user's login details. + login_details : UserAuthLoginDetails + An object containing the user's login details. exec_stub : ExecStub - An instance of `ExecStub` used for communication with the SuperLink. + A stub for executing RPC calls to the server. Returns ------- - user_auth_config : dict[str, Any] - A dictionary containing the user's authentication configuration - in JSON format. + UserAuthCredentials + The authentication credentials obtained after login. """ @abstractmethod - def __init__(self, user_auth_config_path: Path): + def __init__(self, credentials_path: Path): """Abstract constructor.""" @abstractmethod - def store_tokens(self, user_auth_config: dict[str, Any]) -> None: - """Store authentication tokens from the provided user_auth_config. + def store_tokens(self, credentials: UserAuthCredentials) -> None: + """Store authentication tokens to the `credentials_path`. - The configuration, including tokens, will be saved as a JSON file - at `user_auth_config_path`. + The credentials, including tokens, will be saved as a JSON file + at `credentials_path`. """ @abstractmethod def load_tokens(self) -> None: - """Load authentication tokens from the user_auth_config_path.""" + """Load authentication tokens from the `credentials_path`.""" @abstractmethod def write_tokens_to_metadata( self, metadata: Sequence[tuple[str, Union[str, bytes]]] ) -> Sequence[tuple[str, Union[str, bytes]]]: """Write authentication tokens to the provided metadata.""" + + @abstractmethod + def read_tokens_from_metadata( + self, metadata: Sequence[tuple[str, Union[str, bytes]]] + ) -> Optional[UserAuthCredentials]: + """Read authentication tokens from the provided metadata.""" diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 9ea23e78c009..9968600856be 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -114,6 +114,8 @@ # Constants for user authentication CREDENTIALS_DIR = ".credentials" AUTH_TYPE = "auth_type" +ACCESS_TOKEN_KEY = "access_token" +REFRESH_TOKEN_KEY = "refresh_token" class MessageType: diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index d6b940ba75e6..94d7281f4267 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -266,3 +266,23 @@ class InvalidRunStatusException(BaseException): def __init__(self, message: str) -> None: super().__init__(message) self.message = message + + +# OIDC user authentication types +@dataclass +class UserAuthLoginDetails: + """User authentication login details.""" + + auth_type: str + device_code: str + verification_uri_complete: str + expires_in: int + interval: int + + +@dataclass +class UserAuthCredentials: + """User authentication tokens.""" + + access_token: str + refresh_token: str diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index d3c508a8c6d9..1e4bcc498d39 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x9c\x01\n\x17GetLoginDetailsResponse\x12L\n\rlogin_details\x18\x01 \x03(\x0b\x32\x35.flwr.proto.GetLoginDetailsResponse.LoginDetailsEntry\x1a\x33\n\x11LoginDetailsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x93\x01\n\x14GetAuthTokensRequest\x12G\n\x0c\x61uth_details\x18\x01 \x03(\x0b\x32\x31.flwr.proto.GetAuthTokensRequest.AuthDetailsEntry\x1a\x32\n\x10\x41uthDetailsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x92\x01\n\x15GetAuthTokensResponse\x12\x46\n\x0b\x61uth_tokens\x18\x01 \x03(\x0b\x32\x31.flwr.proto.GetAuthTokensResponse.AuthTokensEntry\x1a\x31\n\x0f\x41uthTokensEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe5\x03\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x14\x66lwr/proto/run.proto\"\xfb\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x35\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe5\x03\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,12 +29,6 @@ _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._options = None _globals['_LISTRUNSRESPONSE_RUNDICTENTRY']._serialized_options = b'8\001' - _globals['_GETLOGINDETAILSRESPONSE_LOGINDETAILSENTRY']._options = None - _globals['_GETLOGINDETAILSRESPONSE_LOGINDETAILSENTRY']._serialized_options = b'8\001' - _globals['_GETAUTHTOKENSREQUEST_AUTHDETAILSENTRY']._options = None - _globals['_GETAUTHTOKENSREQUEST_AUTHDETAILSENTRY']._serialized_options = b'8\001' - _globals['_GETAUTHTOKENSRESPONSE_AUTHTOKENSENTRY']._options = None - _globals['_GETAUTHTOKENSRESPONSE_AUTHTOKENSENTRY']._serialized_options = b'8\001' _globals['_STARTRUNREQUEST']._serialized_start=138 _globals['_STARTRUNREQUEST']._serialized_end=389 _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=316 @@ -54,21 +48,15 @@ _globals['_GETLOGINDETAILSREQUEST']._serialized_start=784 _globals['_GETLOGINDETAILSREQUEST']._serialized_end=808 _globals['_GETLOGINDETAILSRESPONSE']._serialized_start=811 - _globals['_GETLOGINDETAILSRESPONSE']._serialized_end=967 - _globals['_GETLOGINDETAILSRESPONSE_LOGINDETAILSENTRY']._serialized_start=916 - _globals['_GETLOGINDETAILSRESPONSE_LOGINDETAILSENTRY']._serialized_end=967 - _globals['_GETAUTHTOKENSREQUEST']._serialized_start=970 - _globals['_GETAUTHTOKENSREQUEST']._serialized_end=1117 - _globals['_GETAUTHTOKENSREQUEST_AUTHDETAILSENTRY']._serialized_start=1067 - _globals['_GETAUTHTOKENSREQUEST_AUTHDETAILSENTRY']._serialized_end=1117 - _globals['_GETAUTHTOKENSRESPONSE']._serialized_start=1120 - _globals['_GETAUTHTOKENSRESPONSE']._serialized_end=1266 - _globals['_GETAUTHTOKENSRESPONSE_AUTHTOKENSENTRY']._serialized_start=1217 - _globals['_GETAUTHTOKENSRESPONSE_AUTHTOKENSENTRY']._serialized_end=1266 - _globals['_STOPRUNREQUEST']._serialized_start=1268 - _globals['_STOPRUNREQUEST']._serialized_end=1300 - _globals['_STOPRUNRESPONSE']._serialized_start=1302 - _globals['_STOPRUNRESPONSE']._serialized_end=1336 - _globals['_EXEC']._serialized_start=1339 - _globals['_EXEC']._serialized_end=1824 + _globals['_GETLOGINDETAILSRESPONSE']._serialized_end=949 + _globals['_GETAUTHTOKENSREQUEST']._serialized_start=951 + _globals['_GETAUTHTOKENSREQUEST']._serialized_end=994 + _globals['_GETAUTHTOKENSRESPONSE']._serialized_start=996 + _globals['_GETAUTHTOKENSRESPONSE']._serialized_end=1064 + _globals['_STOPRUNREQUEST']._serialized_start=1066 + _globals['_STOPRUNREQUEST']._serialized_end=1098 + _globals['_STOPRUNRESPONSE']._serialized_start=1100 + _globals['_STOPRUNRESPONSE']._serialized_end=1134 + _globals['_EXEC']._serialized_start=1137 + _globals['_EXEC']._serialized_end=1622 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index d77aa26e1aa0..576f4322b316 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -143,77 +143,50 @@ global___GetLoginDetailsRequest = GetLoginDetailsRequest class GetLoginDetailsResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class LoginDetailsEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - value: typing.Text - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Text = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - LOGIN_DETAILS_FIELD_NUMBER: builtins.int - @property - def login_details(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + AUTH_TYPE_FIELD_NUMBER: builtins.int + DEVICE_CODE_FIELD_NUMBER: builtins.int + VERIFICATION_URI_COMPLETE_FIELD_NUMBER: builtins.int + EXPIRES_IN_FIELD_NUMBER: builtins.int + INTERVAL_FIELD_NUMBER: builtins.int + auth_type: typing.Text + device_code: typing.Text + verification_uri_complete: typing.Text + expires_in: builtins.int + interval: builtins.int def __init__(self, *, - login_details: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + auth_type: typing.Text = ..., + device_code: typing.Text = ..., + verification_uri_complete: typing.Text = ..., + expires_in: builtins.int = ..., + interval: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["login_details",b"login_details"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["auth_type",b"auth_type","device_code",b"device_code","expires_in",b"expires_in","interval",b"interval","verification_uri_complete",b"verification_uri_complete"]) -> None: ... global___GetLoginDetailsResponse = GetLoginDetailsResponse class GetAuthTokensRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class AuthDetailsEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - value: typing.Text - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Text = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - AUTH_DETAILS_FIELD_NUMBER: builtins.int - @property - def auth_details(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + DEVICE_CODE_FIELD_NUMBER: builtins.int + device_code: typing.Text def __init__(self, *, - auth_details: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + device_code: typing.Text = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["auth_details",b"auth_details"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["device_code",b"device_code"]) -> None: ... global___GetAuthTokensRequest = GetAuthTokensRequest class GetAuthTokensResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class AuthTokensEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - value: typing.Text - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Text = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - AUTH_TOKENS_FIELD_NUMBER: builtins.int - @property - def auth_tokens(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + ACCESS_TOKEN_FIELD_NUMBER: builtins.int + REFRESH_TOKEN_FIELD_NUMBER: builtins.int + access_token: typing.Text + refresh_token: typing.Text def __init__(self, *, - auth_tokens: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + access_token: typing.Text = ..., + refresh_token: typing.Text = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["auth_tokens",b"auth_tokens"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["access_token",b"access_token","refresh_token",b"refresh_token"]) -> None: ... global___GetAuthTokensResponse = GetAuthTokensResponse class StopRunRequest(google.protobuf.message.Message): diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index b5c7ae95e224..3d5acd9589d3 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -263,11 +263,10 @@ def run_superlink() -> None: # Obtain certificates certificates = try_obtain_server_certificates(args, args.fleet_api_type) - user_auth_config = _try_obtain_user_auth_config(args) auth_plugin: Optional[ExecAuthPlugin] = None - # user_auth_config is None only if the args.user_auth_config is not provided - if user_auth_config is not None: - auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config) + # Load the auth plugin if the args.user_auth_config is provided + if cfg_path := getattr(args, "user_auth_config", None): + auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path)) # Initialize StateFactory state_factory = LinkStateFactory(args.database) @@ -584,21 +583,20 @@ def _try_setup_node_authentication( ) -def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]: - if getattr(args, "user_auth_config", None) is not None: - with open(args.user_auth_config, encoding="utf-8") as file: - config: dict[str, Any] = yaml.safe_load(file) - return config - return None +def _try_obtain_exec_auth_plugin(config_path: Path) -> Optional[ExecAuthPlugin]: + # Load YAML file + with config_path.open("r", encoding="utf-8") as file: + config: dict[str, Any] = yaml.safe_load(file) - -def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]: + # Load authentication configuration auth_config: dict[str, Any] = config.get("authentication", {}) auth_type: str = auth_config.get(AUTH_TYPE, "") + + # Load authentication plugin try: all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins() auth_plugin_class = all_plugins[auth_type] - return auth_plugin_class(config=auth_config) + return auth_plugin_class(user_auth_config_path=config_path) except KeyError: if auth_type != "": sys.exit( diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index c89f9ee47840..484333e12407 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -181,8 +181,20 @@ def GetLoginDetails( "ExecServicer initialized without user authentication", ) raise grpc.RpcError() # This line is unreachable + + # Get login details + details = self.auth_plugin.get_login_details() + + # Return empty response if details is None + if details is None: + return GetLoginDetailsResponse() + return GetLoginDetailsResponse( - login_details=self.auth_plugin.get_login_details() + auth_type=details.auth_type, + device_code=details.device_code, + verification_uri_complete=details.verification_uri_complete, + expires_in=details.expires_in, + interval=details.interval, ) def GetAuthTokens( @@ -196,8 +208,17 @@ def GetAuthTokens( "ExecServicer initialized without user authentication", ) raise grpc.RpcError() # This line is unreachable + + # Get auth tokens + credentials = self.auth_plugin.get_auth_tokens(request.device_code) + + # Return empty response if credentials is None + if credentials is None: + return GetAuthTokensResponse() + return GetAuthTokensResponse( - auth_tokens=self.auth_plugin.get_auth_tokens(dict(request.auth_details)) + access_token=credentials.access_token, + refresh_token=credentials.refresh_token, )