From 5ba79ff1974c178419adc10c28681d0fcb1c3712 Mon Sep 17 00:00:00 2001 From: Doug Date: Thu, 21 Apr 2022 16:58:07 +0100 Subject: [PATCH 1/5] Add AuthenticationService and RegistrationWizard. --- Riot/Categories/MXHTTPClient+Async.swift | 44 +++ Riot/Categories/MXRestClient+Async.swift | 142 ++++++++ .../MatrixKit/Models/Account/MXKAccount.h | 2 +- .../MatrixKit/Models/Account/MXKAccount.m | 2 +- .../AuthenticationCoordinatorState.swift | 41 +++ .../Common/AuthenticationModels.swift | 47 +++ .../Common/HomeserverAddress.swift | 29 ++ .../MatrixSDK/AuthenticationPendingData.swift | 43 +++ .../MatrixSDK/AuthenticationService.swift | 334 ++++++++++++++++++ .../Service/MatrixSDK/LoginModels.swift | 58 +++ .../Service/MatrixSDK/LoginWizard.swift | 21 ++ .../MatrixSDK/RegistrationModels.swift | 197 +++++++++++ .../MatrixSDK/RegistrationWizard.swift | 245 +++++++++++++ .../Service/MatrixSDK/SessionCreator.swift | 38 ++ .../Service/MatrixSDK/ThreePIDModels.swift | 122 +++++++ .../Common/ErrorHandling/AlertInfo.swift | 14 +- 16 files changed, 1374 insertions(+), 5 deletions(-) create mode 100644 Riot/Categories/MXHTTPClient+Async.swift create mode 100644 Riot/Categories/MXRestClient+Async.swift create mode 100644 Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationModels.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/SessionCreator.swift create mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift diff --git a/Riot/Categories/MXHTTPClient+Async.swift b/Riot/Categories/MXHTTPClient+Async.swift new file mode 100644 index 0000000000..93ac419f93 --- /dev/null +++ b/Riot/Categories/MXHTTPClient+Async.swift @@ -0,0 +1,44 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +@available(iOS 13.0, *) +extension MXHTTPClient { + /// Errors thrown by the async extensions to `MXHTTPClient.` + enum ClientError: Error { + /// An unexpected response was received. + case invalidResponse + /// The error that occurred was missing from the closure. + case unknownError + } + + /// An async version of `request(withMethod:path:parameters:success:failure:)`. + func request(withMethod method: String, path: String, parameters: [AnyHashable: Any]) async throws -> [AnyHashable: Any] { + try await withCheckedThrowingContinuation { continuation in + request(withMethod: method, path: path, parameters: parameters) { jsonDictionary in + guard let jsonDictionary = jsonDictionary else { + continuation.resume(with: .failure(ClientError.invalidResponse)) + return + } + + continuation.resume(with: .success(jsonDictionary)) + } failure: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) + } + } + } +} diff --git a/Riot/Categories/MXRestClient+Async.swift b/Riot/Categories/MXRestClient+Async.swift new file mode 100644 index 0000000000..e58ff06af6 --- /dev/null +++ b/Riot/Categories/MXRestClient+Async.swift @@ -0,0 +1,142 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +@available(iOS 13.0, *) +extension MXRestClient { + /// Errors thrown by the async extensions to `MXRestClient.` + enum ClientError: Error { + /// An unexpected response was received. + case invalidResponse + /// The error that occurred was missing from the closure. + case unknownError + /// An error occurred whilst decoding the received JSON. + case decodingError + } + + /// An async version of `wellKnow(_:failure:)`. + func wellKnown() async throws -> MXWellKnown { + try await withCheckedThrowingContinuation { continuation in + wellKnow { wellKnown in + guard let wellKnown = wellKnown else { + continuation.resume(with: .failure(ClientError.invalidResponse)) + return + } + + continuation.resume(with: .success(wellKnown)) + } failure: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) + } + } + } + + /// An async version of `getRegisterSession(completion:)`. + func getRegisterSession() async throws -> MXAuthenticationSession { + try await withCheckedThrowingContinuation { continuation in + getRegisterSession { response in + guard let session = response.value else { + continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) + return + } + + continuation.resume(with: .success(session)) + } + } + } + + /// An async version of `getLoginSession(completion:)`. + func getLoginSession() async throws -> MXAuthenticationSession { + try await withCheckedThrowingContinuation { continuation in + getLoginSession { response in + guard let session = response.value else { + continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) + return + } + + continuation.resume(with: .success(session)) + } + } + } + + /// An async version of `isUsernameAvailable(_:completion:)`. + func isUsernameAvailable(_ username: String) async throws -> Bool { + try await withCheckedThrowingContinuation { continuation in + isUsernameAvailable(username) { response in + guard let availability = response.value else { + continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) + return + } + + continuation.resume(with: .success(availability.available)) + } + } + } + + /// An async version of `register(parameters:completion:)`. + func register(parameters: [String: Any]) async throws -> MXLoginResponse { + try await withCheckedThrowingContinuation { continuation in + register(parameters: parameters) { response in + guard let jsonDictionary = response.value else { + continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) + return + } + + guard let loginResponse = MXLoginResponse(fromJSON: jsonDictionary) else { + continuation.resume(with: .failure(ClientError.decodingError)) + return + } + + continuation.resume(with: .success(loginResponse)) + } + } + } + + /// An async version of both `requestToken(forEmail:isDuringRegistration:clientSecret:sendAttempt:nextLink:success:failure:)` and + /// `requestToken(forPhoneNumber:isDuringRegistration:countryCode:clientSecret:sendAttempt:nextLink:success:failure:)` depending + /// on the kind of third party ID is supplied to the `threePID` parameter. + func requestTokenDuringRegistration(for threePID: RegisterThreePID, clientSecret: String, sendAttempt: UInt) async throws -> RegistrationThreePIDTokenResponse { + try await withCheckedThrowingContinuation { continuation in + switch threePID { + case .email(let email): + requestToken(forEmail: email, isDuringRegistration: true, clientSecret: clientSecret, sendAttempt: sendAttempt, nextLink: nil) { sessionID in + guard let sessionID = sessionID else { + continuation.resume(with: .failure(ClientError.invalidResponse)) + return + } + + let response = RegistrationThreePIDTokenResponse(sessionID: sessionID) + continuation.resume(with: .success(response)) + } failure: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) + } + + case .msisdn(let msisdn, let countryCode): + requestToken(forPhoneNumber: msisdn, isDuringRegistration: true, countryCode: countryCode, clientSecret: clientSecret, sendAttempt: sendAttempt, nextLink: nil) { sessionID, msisdn, submitURL in + guard let sessionID = sessionID else { + continuation.resume(with: .failure(ClientError.invalidResponse)) + return + } + + let response = RegistrationThreePIDTokenResponse(sessionID: sessionID, submitURL: submitURL, msisdn: msisdn) + continuation.resume(with: .success(response)) + } failure: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) + } + } + } + } +} diff --git a/Riot/Modules/MatrixKit/Models/Account/MXKAccount.h b/Riot/Modules/MatrixKit/Models/Account/MXKAccount.h index e2c56d63d0..785972fa5a 100644 --- a/Riot/Modules/MatrixKit/Models/Account/MXKAccount.h +++ b/Riot/Modules/MatrixKit/Models/Account/MXKAccount.h @@ -175,7 +175,7 @@ typedef BOOL (^MXKAccountOnCertificateChange)(MXKAccount *mxAccount, NSData *cer @param credentials user's credentials */ -- (instancetype)initWithCredentials:(MXCredentials*)credentials; +- (nonnull instancetype)initWithCredentials:(MXCredentials*)credentials; /** Create a matrix session based on the provided store. diff --git a/Riot/Modules/MatrixKit/Models/Account/MXKAccount.m b/Riot/Modules/MatrixKit/Models/Account/MXKAccount.m index 4de8b53a6f..9c35500d36 100644 --- a/Riot/Modules/MatrixKit/Models/Account/MXKAccount.m +++ b/Riot/Modules/MatrixKit/Models/Account/MXKAccount.m @@ -135,7 +135,7 @@ + (UIColor*)presenceColor:(MXPresence)presence } } -- (instancetype)initWithCredentials:(MXCredentials*)credentials +- (nonnull instancetype)initWithCredentials:(MXCredentials*)credentials { if (self = [super init]) { diff --git a/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift b/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift new file mode 100644 index 0000000000..9eebfda03f --- /dev/null +++ b/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift @@ -0,0 +1,41 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import MatrixSDK + +@available(iOS 14.0, *) +struct AuthenticationCoordinatorState { + // MARK: User choices + // var serverType: ServerType = .unknown + // var signMode: SignMode = .unknown + var resetPasswordEmail: String? + + /// The homeserver address as returned by the server. + var homeserverAddress: String? + /// The homeserver address as input by the user (it can differ to the well-known request). + var homeserverAddressFromUser: String? + + /// For SSO session recovery + var deviceId: String? + + // MARK: Network result + var loginMode: LoginMode = .unknown + /// Supported types for the login. + var loginModeSupportedTypes = [MXLoginFlow]() + var knownCustomHomeServersUrls = [String]() + var isForceLoginFallbackEnabled = false +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift new file mode 100644 index 0000000000..5435051c24 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift @@ -0,0 +1,47 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/// Errors that can be thrown from `AuthenticationService`, `RegistrationWizard` and `LoginWizard`. +enum AuthenticationError: String, Error { + // MARK: AuthenticationService + /// A failure to convert a struct into a dictionary. + case dictionaryError + case invalidHomeserver + case loginFlowNotCalled + case missingRegistrationWizard + case missingMXRestClient + + // MARK: RegistrationWizard + case createAccountNotCalled + case noPendingThreePID + case missingThreePIDURL + case threePIDValidationFailure + case threePIDClientFailure +} + +/// Represents an SSO Identity Provider as provided in a login flow. +struct SSOIdentityProvider: Identifiable { + /// The identifier field (id field in JSON) is the Identity Provider identifier used for the SSO Web page redirection `/login/sso/redirect/{idp_id}`. + let id: String + /// The name field is a human readable string intended to be printed by the client. + let name: String + /// The brand field is optional. It allows the client to style the login button to suit a particular brand. + let brand: String? + /// The icon field is an optional field that points to an icon representing the identity provider. If present then it must be an HTTPS URL to an image resource. + let iconURL: String? +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift b/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift new file mode 100644 index 0000000000..c65de672a3 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift @@ -0,0 +1,29 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +class HomeserverAddress { + /// Ensures the address contains a scheme, otherwise makes it `https`. + static func sanitize(_ address: String) -> String { + !address.contains("://") ? "https://\(address.lowercased())" : address.lowercased() + } + + /// Strips the `https://` away from the address (but leaves `http://`) for display in labels. + static func displayable(_ address: String) -> String { + address.replacingOccurrences(of: "https://", with: "") + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift new file mode 100644 index 0000000000..2aa4bc4992 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift @@ -0,0 +1,43 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/// This class holds all pending data when creating a session, either by login or by register +class AuthenticationPendingData { + let homeserverAddress: String + + // MARK: - Common + + var clientSecret = UUID().uuidString + var sendAttempt: UInt = 0 + + // MARK: - For login + + // var resetPasswordData: ResetPasswordData? + + // MARK: - For registration + + var currentSession: String? + var isRegistrationStarted = false + var currentThreePIDData: ThreePIDData? + + // MARK: - Setup + + init(homeserverAddress: String) { + self.homeserverAddress = homeserverAddress + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift new file mode 100644 index 0000000000..53692461f3 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift @@ -0,0 +1,334 @@ +// +// Copyright 2021 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +@available(iOS 14.0, *) +protocol AuthenticationServiceDelegate: AnyObject { + func authenticationServiceDidUpdateRegistrationParameters(_ authenticationService: AuthenticationService) +} + +@available(iOS 14.0, *) +class AuthenticationService: NSObject { + + /// The shared service object. + static let shared = AuthenticationService() + + // MARK: - Properties + + // MARK: Private + + /// The rest client used to make authentication requests. + private var client: MXRestClient + /// Pending data collected as the authentication flow progresses. + private var pendingData: AuthenticationPendingData? + /// The current registration wizard or `nil` if `registrationWizard()` hasn't been called. + private var currentRegistrationWizard: RegistrationWizard? + /// The current login wizard or `nil` if `loginWizard()` hasn't been called. + private var currentLoginWizard: LoginWizard? + /// The object used to create a new `MXSession` when authentication has completed. + private var sessionCreator = SessionCreator() + + // MARK: Public + + /// The address of the homeserver that the service is using. + var homeserverAddress: String { + state.homeserverAddress ?? RiotSettings.shared.homeserverUrlString + } + + + // MARK: Android OnboardingViewModel + /// The current state of the authentication flow. + private var state = AuthenticationCoordinatorState() + /// The currently executing async task. + private var currentTask: Task? { + willSet { + currentTask?.cancel() + } + } + + + // MARK: - Setup + + override init() { + guard let homeserverURL = URL(string: RiotSettings.shared.homeserverUrlString) else { + fatalError("[AuthenticationService]: Failed to create URL from default homeserver URL string.") + } + + client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + super.init() + } + + // MARK: - Android OnboardingViewModel + + func loginFlow(homeserverAddress: String) async { + currentTask = Task { + cancelPendingLoginOrRegistration() + + do { + let data = try await loginFlow(for: homeserverAddress) + + guard !Task.isCancelled else { return } + + // Valid Homeserver, add it to the history. + // Note: we add what the user has input, as the data can contain a different value. + RiotSettings.shared.homeserverUrlString = homeserverAddress + + let loginMode: LoginMode + + if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }), + data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { + loginMode = .ssoAndPassword(ssoIdentityProviders: data.ssoIdentityProviders) + } else if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }) { + loginMode = .sso(ssoIdentityProviders: data.ssoIdentityProviders) + } else if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { + loginMode = .password + } else { + loginMode = .unsupported + } + + state.homeserverAddressFromUser = homeserverAddress + state.homeserverAddress = data.homeserverAddress + state.loginMode = loginMode + state.loginModeSupportedTypes = data.supportedLoginTypes + } catch { + #warning("Show an error message and/or handle the error?") + return + } + } + } + + func refreshServer(homeserverAddress: String) async throws -> (LoginFlowResult, RegistrationResult) { + let loginFlows = try await loginFlow(for: homeserverAddress) + let wizard = try registrationWizard() + let registrationFlow = try await wizard.registrationFlow() + + state.homeserverAddress = homeserverAddress + + return (loginFlows, registrationFlow) + } + + // MARK: - Public + + /// Whether authentication is needed by checking for any accounts. + /// - Returns: `true` there are no accounts or if there is an inactive account that has had a soft logout. + var needsAuthentication: Bool { + MXKAccountManager.shared().accounts.isEmpty || softLogoutCredentials != nil + } + + /// Credentials to be used when authenticating after soft logout, otherwise `nil`. + var softLogoutCredentials: MXCredentials? { + guard MXKAccountManager.shared().activeAccounts.isEmpty else { return nil } + for account in MXKAccountManager.shared().accounts { + if account.isSoftLogout { + return account.mxCredentials + } + } + + return nil + } + + /// Get the last authenticated [Session], if there is an active session. + /// - Returns: The last active session if any, or `nil` + var lastAuthenticatedSession: MXSession? { + MXKAccountManager.shared().activeAccounts?.first?.mxSession + } + + enum AuthenticationMode { + case login + case registration + } + + /// Request the supported login flows for this homeserver. + /// This is the first method to call to be able to get a wizard to login or to create an account + /// - Parameter homeserverAddress: The homeserver string entered by the user. + func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { + pendingData = nil + + let homeserverAddress = HomeserverAddress.sanitize(homeserverAddress) + + guard var homeserverURL = URL(string: homeserverAddress) else { + throw AuthenticationError.invalidHomeserver + } + + let pendingData = AuthenticationPendingData(homeserverAddress: homeserverAddress) + + if let wellKnown = try? await wellKnown(for: homeserverURL), + let baseURL = URL(string: wellKnown.homeServer.baseUrl) { + homeserverURL = baseURL + } + + #warning("Add an unrecognized certificate handler.") + let client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + let loginFlow = try await getLoginFlowResult(client: client) + + self.client = client + self.pendingData = pendingData + + return loginFlow + } + + /// Request the supported login flows for the corresponding session. + /// This method is used to get the flows for a server after a soft-logout. + /// - Parameter session: The MXSession where a soft-logout has occurred. + func loginFlow(for session: MXSession) async throws -> LoginFlowResult { + pendingData = nil + + guard let client = session.matrixRestClient else { throw AuthenticationError.missingMXRestClient } + let pendingData = AuthenticationPendingData(homeserverAddress: client.homeserver) + + let loginFlow = try await getLoginFlowResult(client: session.matrixRestClient) + + self.client = client + self.pendingData = pendingData + + return loginFlow + } + + /// Get a SSO url + func getSSOURL(redirectUrl: String, deviceId: String?, providerId: String?) -> String? { + fatalError("Not implemented.") + } + + /// Get the sign in or sign up fallback URL + func fallbackURL(for authenticationMode: AuthenticationMode) -> URL { + switch authenticationMode { + case .login: + return client.loginFallbackURL + case .registration: + return client.registerFallbackURL + } + } + + /// Return a LoginWizard, to login to the homeserver. The login flow has to be retrieved first. + /// + /// See ``LoginWizard`` for more details + func loginWizard() throws -> LoginWizard { + if let currentLoginWizard = currentLoginWizard { + return currentLoginWizard + } + + guard let pendingData = pendingData else { + throw AuthenticationError.loginFlowNotCalled + } + + let wizard = LoginWizard() + return wizard + } + + /// Return a RegistrationWizard, to create a matrix account on the homeserver. The login flow has to be retrieved first. + /// + /// See ``RegistrationWizard`` for more details. + func registrationWizard() throws -> RegistrationWizard { + if let currentRegistrationWizard = currentRegistrationWizard { + return currentRegistrationWizard + } + + guard let pendingData = pendingData else { + throw AuthenticationError.loginFlowNotCalled + } + + + let wizard = RegistrationWizard(client: client, pendingData: pendingData) + currentRegistrationWizard = wizard + return wizard + } + + /// True when login and password has been sent with success to the homeserver + var isRegistrationStarted: Bool { + currentRegistrationWizard?.isRegistrationStarted ?? false + } + + /// Cancel pending login or pending registration + func cancelPendingLoginOrRegistration() { + currentTask?.cancel() + + currentLoginWizard = nil + currentRegistrationWizard = nil + + // Keep only the homesever config + guard let pendingData = pendingData else { + // Should not happen + return + } + + self.pendingData = AuthenticationPendingData(homeserverAddress: pendingData.homeserverAddress) + } + + /// Reset all pending settings, including current HomeServerConnectionConfig + func reset() { + pendingData = nil + currentRegistrationWizard = nil + currentLoginWizard = nil + } + + /// Create a session after a SSO successful login + func makeSessionFromSSO(credentials: MXCredentials) -> MXSession { + sessionCreator.createSession(credentials: credentials, client: client) + } + +// /// Perform a well-known request, using the domain from the matrixId +// func getWellKnownData(matrixId: String, +// homeServerConnectionConfig: HomeServerConnectionConfig?) async -> WellknownResult { +// +// } +// +// /// Authenticate with a matrixId and a password +// /// Usually call this after a successful call to getWellKnownData() +// /// - Parameter homeServerConnectionConfig the information about the homeserver and other configuration +// /// - Parameter matrixId the matrixId of the user +// /// - Parameter password the password of the account +// /// - Parameter initialDeviceName the initial device name +// /// - Parameter deviceId the device id, optional. If not provided or null, the server will generate one. +// func directAuthentication(homeServerConnectionConfig: HomeServerConnectionConfig, +// matrixId: String, +// password: String, +// initialDeviceName: String, +// deviceId: String? = nil) async -> MXSession { +// +// } + + // MARK: - Private + + private func getLoginFlowResult(client: MXRestClient/*, versions: Versions*/) async throws -> LoginFlowResult { + // Get the login flow + let loginFlowResponse = try await client.getLoginSession() + + let identityProviders = loginFlowResponse.flows?.compactMap { $0 as? MXLoginSSOFlow }.first?.identityProviders ?? [] + return LoginFlowResult(supportedLoginTypes: loginFlowResponse.flows?.compactMap { $0 } ?? [], + ssoIdentityProviders: identityProviders.sorted { $0.name < $1.name }.map { $0.ssoIdentityProvider }, + homeserverAddress: client.homeserver) + } + + /// Perform a well-known request on the specified homeserver URL. + private func wellKnown(for homeserverURL: URL) async throws -> MXWellKnown { + let wellKnownClient = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + // The .well-known/matrix/client API is often just a static file returned with no content type. + // Make our HTTP client compatible with this behaviour + wellKnownClient.acceptableContentTypes = nil + + return try await wellKnownClient.wellKnown() + } +} + +extension MXLoginSSOIdentityProvider { + var ssoIdentityProvider: SSOIdentityProvider { + SSOIdentityProvider(id: identifier, name: name, brand: brand, iconURL: icon) + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift new file mode 100644 index 0000000000..22108de730 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift @@ -0,0 +1,58 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +struct LoginFlowResult { + let supportedLoginTypes: [MXLoginFlow] + let ssoIdentityProviders: [SSOIdentityProvider] + let homeserverAddress: String +} + +enum LoginMode { + case unknown + case password + case sso(ssoIdentityProviders: [SSOIdentityProvider]) + case ssoAndPassword(ssoIdentityProviders: [SSOIdentityProvider]) + case unsupported + + var ssoIdentityProviders: [SSOIdentityProvider]? { + switch self { + case .sso(let ssoIdentityProviders), .ssoAndPassword(let ssoIdentityProviders): + return ssoIdentityProviders + default: + return nil + } + } + + var hasSSO: Bool { + switch self { + case .sso, .ssoAndPassword: + return true + default: + return false + } + } + + var supportsSignModeScreen: Bool { + switch self { + case .password, .ssoAndPassword: + return true + case .unknown, .unsupported, .sso: + return false + } + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift new file mode 100644 index 0000000000..40fe8098aa --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift @@ -0,0 +1,21 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +class LoginWizard { + // TODO +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationModels.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationModels.swift new file mode 100644 index 0000000000..e234aa26fe --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationModels.swift @@ -0,0 +1,197 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/// The parameters used for registration requests. +struct RegistrationParameters: Codable { + /// Authentication parameters + var auth: AuthenticationParameters? + + /// The account username + var username: String? + + /// The account password + var password: String? + + /// Device name + var initialDeviceDisplayName: String? + + /// Temporary flag to notify the server that we support MSISDN flow. Used to prevent old app + /// versions to end up in fallback because the HS returns the MSISDN flow which they don't support + var xShowMSISDN: Bool? + + enum CodingKeys: String, CodingKey { + case auth + case username + case password + case initialDeviceDisplayName = "initial_device_display_name" + case xShowMSISDN = "x_show_msisdn" + } + + /// The parameters as a JSON dictionary for use in MXRestClient. + func dictionary() throws -> [String: Any] { + let jsonData = try JSONEncoder().encode(self) + let object = try JSONSerialization.jsonObject(with: jsonData) + guard let dictionary = object as? [String: Any] else { + throw AuthenticationError.dictionaryError + } + + return dictionary + } +} + +/// The data passed to the `auth` parameter in authentication requests. +struct AuthenticationParameters: Codable { + /// The type of authentication taking place. The identifier from `MXLoginFlowType`. + let type: String + + /// Note: session can be null for reset password request + var session: String? + + /// parameter for "m.login.recaptcha" type + var captchaResponse: String? + + /// parameter for "m.login.email.identity" type + var threePIDCredentials: ThreePIDCredentials? + + enum CodingKeys: String, CodingKey { + case type + case session + case captchaResponse = "response" + case threePIDCredentials = "threepid_creds" + } + + /// Creates the authentication parameters for a captcha step. + static func captchaParameters(session: String, captchaResponse: String) -> AuthenticationParameters { + AuthenticationParameters(type: kMXLoginFlowTypeRecaptcha, session: session, captchaResponse: captchaResponse) + } + + /// Creates the authentication parameters for a third party ID step using an email address. + static func emailIdentityParameters(session: String, threePIDCredentials: ThreePIDCredentials) -> AuthenticationParameters { + AuthenticationParameters(type: kMXLoginFlowTypeEmailIdentity, session: session, threePIDCredentials: threePIDCredentials) + } + + // Note that there is a bug in Synapse (needs investigation), but if we pass .msisdn, + // the homeserver answer with the login flow with MatrixError fields and not with a simple MatrixError 401. + /// Creates the authentication parameters for a third party ID step using a phone number. + static func msisdnIdentityParameters(session: String, threePIDCredentials: ThreePIDCredentials) -> AuthenticationParameters { + AuthenticationParameters(type: kMXLoginFlowTypeMSISDN, session: session, threePIDCredentials: threePIDCredentials) + } + + /// Creates the authentication parameters for a password reset step. + static func resetPasswordParameters(clientSecret: String, sessionID: String) -> AuthenticationParameters { + AuthenticationParameters(type: kMXLoginFlowTypeEmailIdentity, + session: nil, + threePIDCredentials: ThreePIDCredentials(clientSecret: clientSecret, sessionID: sessionID)) + } +} + +/// The result from a response of a registration flow step. +enum RegistrationResult { + /// Registration has completed, creating an `MXSession` for the account. + case success(MXSession) + /// The request was successful but there are pending steps to complete. + case flowResponse(FlowResult) +} + +/// The state of an authentication flow after a step has been completed. +struct FlowResult { + /// The stages in the flow that are yet to be completed. + let missingStages: [Stage] + /// The stages in the flow that have been completed. + let completedStages: [Stage] + + /// A stage in the authentication flow. + enum Stage { + /// The stage with the type `m.login.recaptcha`. + case reCaptcha(mandatory: Bool, publicKey: String) + + /// The stage with the type `m.login.email.identity`. + case email(mandatory: Bool) + + /// The stage with the type `m.login.msisdn`. + case msisdn(mandatory: Bool) + + /// The stage with the type `m.login.dummy`. + /// + /// This stage can be mandatory if there is no other stages. In this case the account cannot + /// be created by just sending a username and a password, the dummy stage has to be completed. + case dummy(mandatory: Bool) + + /// The stage with the type `m.login.terms`. + case terms(mandatory: Bool, policies: [String: String]) + + /// A stage of an unknown type. + case other(mandatory: Bool, type: String, params: [AnyHashable: Any]) + + /// Whether the stage is a dummy stage that is also mandatory. + var isDummyAndMandatory: Bool { + guard case let .dummy(isMandatory) = self else { return false } + return isMandatory + } + } +} + +extension MXAuthenticationSession { + /// The flows from the session mapped as a `FlowResult` value. + var flowResult: FlowResult { + let allFlowTypes = Set(flows.flatMap { $0.stages ?? [] }) + var missingStages = [FlowResult.Stage]() + var completedStages = [FlowResult.Stage]() + + allFlowTypes.forEach { flow in + let isMandatory = flows.allSatisfy { $0.stages.contains(flow) } + + let stage: FlowResult.Stage + switch flow { + case kMXLoginFlowTypeRecaptcha: + let parameters = params[flow] as? [AnyHashable: Any] + let publicKey = parameters?["public_key"] as? String + stage = .reCaptcha(mandatory: isMandatory, publicKey: publicKey ?? "") + case kMXLoginFlowTypeDummy: + stage = .dummy(mandatory: isMandatory) + case kMXLoginFlowTypeTerms: + let parameters = params[flow] as? [String: String] + stage = .terms(mandatory: isMandatory, policies: parameters ?? [:]) + case kMXLoginFlowTypeMSISDN: + stage = .msisdn(mandatory: isMandatory) + case kMXLoginFlowTypeEmailIdentity: + stage = .email(mandatory: isMandatory) + default: + let parameters = params[flow] as? [AnyHashable: Any] + stage = .other(mandatory: isMandatory, type: flow, params: parameters ?? [:]) + } + + if let completed = completed, completed.contains(flow) { + completedStages.append(stage) + } else { + missingStages.append(stage) + } + } + + return FlowResult(missingStages: missingStages, completedStages: completedStages) + } + + /// Determines the next stage to be completed in the flow. + func nextUncompletedStage(flowIndex: Int = 0) -> String? { + guard flows.count < flowIndex else { return nil } + return flows[flowIndex].stages.first { + guard let completed = completed else { return false } + return !completed.contains($0) + } + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift new file mode 100644 index 0000000000..dc1ff5391a --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift @@ -0,0 +1,245 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +@available(iOS 14.0, *) +/// Set of methods to be able to create an account on a homeserver. +/// +/// Common scenario to register an account successfully: +/// - Call `registrationFlow` to check that you application supports all the mandatory registration stages +/// - Call `createAccount` to start the account creation +/// - Fulfil all mandatory stages using the methods `performReCaptcha` `acceptTerms` `dummy`, etc. +/// +/// More documentation can be found in the file https://github.com/vector-im/element-android/blob/main/docs/signup.md +/// and https://matrix.org/docs/spec/client_server/latest#account-registration-and-management +class RegistrationWizard { + let client: MXRestClient + let sessionCreator: SessionCreator + let pendingData: AuthenticationPendingData + + /// This is the current ThreePID, waiting for validation. The SDK will store it in database, so it can be + /// restored even if the app has been killed during the registration + var currentThreePID: String? { + guard let threePid = pendingData.currentThreePIDData?.threePID else { return nil } + + switch threePid { + case .email(let string): + return string + case .msisdn(let msisdn, _): + return pendingData.currentThreePIDData?.registrationResponse.formattedMSISDN ?? msisdn + } + } + + /// True when login and password have been sent with success to the homeserver, + /// i.e. `createAccount` has been called successfully. + var isRegistrationStarted: Bool { + pendingData.isRegistrationStarted + } + + init(client: MXRestClient, sessionCreator: SessionCreator = SessionCreator(), pendingData: AuthenticationPendingData) { + self.client = client + self.sessionCreator = sessionCreator + self.pendingData = pendingData + } + + /// Call this method to get the possible registration flow of the current homeserver. + /// It can be useful to ensure that your application implementation supports all the stages + /// required to create an account. If it is not the case, you will have to use the web fallback + /// to let the user create an account with your application. + /// See `AuthenticationService.getFallbackUrl` + func registrationFlow() async throws -> RegistrationResult { + let parameters = RegistrationParameters() + return try await performRegistrationRequest(parameters: parameters) + } + + /// Can be call to check is the desired username is available for registration on the current homeserver. + /// It may also fails if the desired username is not correctly formatted or does not follow any restriction on + /// the homeserver. Ex: username with only digits may be rejected. + /// - Parameter username the desired username. Ex: "alice" + func registrationAvailable(username: String) async throws -> Bool { + try await client.isUsernameAvailable(username) + } + + /// This is the first method to call in order to create an account and start the registration process. + /// + /// - Parameter username the desired username. Ex: "alice" + /// - Parameter password the desired password + /// - Parameter initialDeviceDisplayName the device display name + func createAccount(username: String?, + password: String?, + initialDeviceDisplayName: String?) async throws -> RegistrationResult { + let parameters = RegistrationParameters(username: username, password: password, initialDeviceDisplayName: initialDeviceDisplayName) + let result = try await performRegistrationRequest(parameters: parameters) + pendingData.isRegistrationStarted = true + return result + } + + /// Perform the "m.login.recaptcha" stage. + /// + /// - Parameter response: The response from ReCaptcha + func performReCaptcha(response: String) async throws -> RegistrationResult { + guard let session = pendingData.currentSession else { + throw AuthenticationError.createAccountNotCalled + } + + let parameters = RegistrationParameters(auth: AuthenticationParameters.captchaParameters(session: session, captchaResponse: response)) + return try await performRegistrationRequest(parameters: parameters) + } + + /// Perform the "m.login.terms" stage. + func acceptTerms() async throws -> RegistrationResult { + guard let session = pendingData.currentSession else { + throw AuthenticationError.createAccountNotCalled + } + + let parameters = RegistrationParameters(auth: AuthenticationParameters(type: kMXLoginFlowTypeTerms, session: session)) + return try await performRegistrationRequest(parameters: parameters) + } + + /// Perform the "m.login.dummy" stage. + func dummy() async throws -> RegistrationResult { + guard let session = pendingData.currentSession else { + throw AuthenticationError.createAccountNotCalled + } + + let parameters = RegistrationParameters(auth: AuthenticationParameters(type: kMXLoginFlowTypeDummy, session: session)) + return try await performRegistrationRequest(parameters: parameters) + } + + /// Perform the "m.login.email.identity" or "m.login.msisdn" stage. + /// + /// - Parameter threePID the threePID to add to the account. If this is an email, the homeserver will send an email + /// to validate it. For a msisdn a SMS will be sent. + func addThreePID(threePID: RegisterThreePID) async throws -> RegistrationResult { + pendingData.currentThreePIDData = nil + return try await sendThreePID(threePID: threePID) + } + + /// Ask the homeserver to send again the current threePID (email or msisdn). + func sendAgainThreePID() async throws -> RegistrationResult { + guard let threePID = pendingData.currentThreePIDData?.threePID else { + throw AuthenticationError.createAccountNotCalled + } + return try await sendThreePID(threePID: threePID) + } + + /// Send the code received by SMS to validate a msisdn. + /// If the code is correct, the registration request will be executed to validate the msisdn. + func handleValidateThreePID(code: String) async throws -> RegistrationResult { + return try await validateThreePid(code: code) + } + + /// Useful to poll the homeserver when waiting for the email to be validated by the user. + /// Once the email is validated, this method will return successfully. + /// - Parameter delay How long to wait before sending the request. + func checkIfEmailHasBeenValidated(delay: TimeInterval) async throws -> RegistrationResult { + guard let parameters = pendingData.currentThreePIDData?.registrationParameters else { + throw AuthenticationError.noPendingThreePID + } + + return try await performRegistrationRequest(parameters: parameters, delay: delay) + } + + // MARK: - Private + + private func validateThreePid(code: String) async throws -> RegistrationResult { + guard let threePIDData = pendingData.currentThreePIDData else { + throw AuthenticationError.noPendingThreePID + } + + guard let url = threePIDData.registrationResponse.submitURL else { + throw AuthenticationError.missingThreePIDURL + } + + + let validationBody = ThreePIDValidationCodeBody(clientSecret: pendingData.clientSecret, + sessionID: threePIDData.registrationResponse.sessionID, + code: code) + let validationDictionary = try validationBody.dictionary() + + #warning("Seems odd to pass a nil baseURL and then the url as the path, yet this is how MXK3PID works") + guard let httpClient = MXHTTPClient(baseURL: nil, andOnUnrecognizedCertificateBlock: nil) else { + throw AuthenticationError.threePIDClientFailure + } + let responseDictionary = try await httpClient.request(withMethod: "POST", path: url, parameters: validationDictionary) + + // Response is a json dictionary with a single success parameter + if responseDictionary["success"] as? Bool == true { + // The entered code is correct + // Same than validate email + let parameters = threePIDData.registrationParameters + return try await performRegistrationRequest(parameters: parameters, delay: 3) + } else { + // The code is not correct + throw AuthenticationError.threePIDValidationFailure + } + } + + private func sendThreePID(threePID: RegisterThreePID) async throws -> RegistrationResult { + guard let session = pendingData.currentSession else { + throw AuthenticationError.createAccountNotCalled + } + + let response = try await client.requestTokenDuringRegistration(for: threePID, + clientSecret: pendingData.clientSecret, + sendAttempt: pendingData.sendAttempt) + + pendingData.sendAttempt += 1 + + let threePIDCredentials = ThreePIDCredentials(clientSecret: pendingData.clientSecret, sessionID: response.sessionID) + let authenticationParameters: AuthenticationParameters + switch threePID { + case .email: + authenticationParameters = AuthenticationParameters.emailIdentityParameters(session: session, threePIDCredentials: threePIDCredentials) + case .msisdn: + authenticationParameters = AuthenticationParameters.msisdnIdentityParameters(session: session, threePIDCredentials: threePIDCredentials) + } + + let parameters = RegistrationParameters(auth: authenticationParameters) + + pendingData.currentThreePIDData = ThreePIDData(threePID: threePID, registrationResponse: response, registrationParameters: parameters) + + // Send the session id for the first time + return try await performRegistrationRequest(parameters: parameters) + } + + private func performRegistrationRequest(parameters: RegistrationParameters, + delay: TimeInterval = 0) async throws -> RegistrationResult { + try await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + + let jsonData = try JSONEncoder().encode(parameters) + guard let dictionary = try JSONSerialization.jsonObject(with: jsonData) as? [String: Any] else { + throw MXRestClient.ClientError.decodingError + } + + do { + let response = try await client.register(parameters: dictionary) + let credentials = MXCredentials(loginResponse: response, andDefaultCredentials: client.credentials) + return .success(sessionCreator.createSession(credentials: credentials, client: client)) + } catch { + let nsError = error as NSError + + guard + let jsonResponse = nsError.userInfo[MXHTTPClientErrorResponseDataKey] as? [String: Any], + let authenticationSession = MXAuthenticationSession(fromJSON: jsonResponse) + else { throw error } + + pendingData.currentSession = authenticationSession.session + return .flowResponse(authenticationSession.flowResult) + } + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/SessionCreator.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/SessionCreator.swift new file mode 100644 index 0000000000..daa468c605 --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/SessionCreator.swift @@ -0,0 +1,38 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/// A WIP class that has common functionality to create a new session. +class SessionCreator { + /// Creates an `MXSession` using the supplied credentials and REST client. + func createSession(credentials: MXCredentials, client: MXRestClient) -> MXSession { + // Report the new account in account manager + if credentials.identityServer == nil { + #warning("Check that the client is actually updated with this info?") + credentials.identityServer = client.identityServer + } + + let account = MXKAccount(credentials: credentials) + + if let identityServer = credentials.identityServer { + account.identityServerURL = identityServer + } + + MXKAccountManager.shared().addAccount(account, andOpenSession: true) + return account.mxSession + } +} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift new file mode 100644 index 0000000000..6c8047c5de --- /dev/null +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift @@ -0,0 +1,122 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +enum RegisterThreePID { + case email(String) + case msisdn(msisdn: String, countryCode: String) +} + +struct ThreePIDCredentials: Codable { + var clientSecret: String? + + var identityServer: String? + + var sessionID: String? + + enum CodingKeys: String, CodingKey { + case clientSecret = "client_secret" + case identityServer = "id_server" + case sessionID = "sid" + } +} + +struct ThreePIDData { + let email: String + let msisdn: String + let country: String + let registrationResponse: RegistrationThreePIDTokenResponse + let registrationParameters: RegistrationParameters + + var threePID: RegisterThreePID { + email.isEmpty ? .msisdn(msisdn: msisdn, countryCode: country) : .email(email) + } +} + +extension ThreePIDData { + init(threePID: RegisterThreePID, + registrationResponse: RegistrationThreePIDTokenResponse, + registrationParameters: RegistrationParameters) { + switch threePID { + case .email(let email): + self.init(email: email, + msisdn: "", + country: "", + registrationResponse: registrationResponse, + registrationParameters: registrationParameters) + case .msisdn(let msisdn, let countryCode): + self.init(email: "", + msisdn: msisdn, + country: countryCode, + registrationResponse: registrationResponse, + registrationParameters: registrationParameters) + } + } +} + +// TODO: This could potentially become an MXJSONModel? +struct RegistrationThreePIDTokenResponse { + /// Required. The session ID. Session IDs are opaque strings that must consist entirely of the characters [0-9a-zA-Z.=_-]. + /// Their length must not exceed 255 characters and they must not be empty. + let sessionID: String + + /// An optional field containing a URL where the client must submit the validation token to, with identical parameters to the Identity + /// Service API's POST /validate/email/submitToken endpoint. The homeserver must send this token to the user (if applicable), + /// who should then be prompted to provide it to the client. + /// + /// If this field is not present, the client can assume that verification will happen without the client's involvement provided + /// the homeserver advertises this specification version in the /versions response (ie: r0.5.0). + var submitURL: String? = nil + + // MARK: - Additional data that may be needed + + var msisdn: String? = nil + var formattedMSISDN: String? = nil + var success: Bool? = nil + + enum CodingKeys: String, CodingKey { + case sessionID = "sid" + case submitURL = "submit_url" + case msisdn + case formattedMSISDN = "intl_fmt" + case success + } +} + +struct ThreePIDValidationCodeBody: Codable { + let clientSecret: String + + let sessionID: String + + let code: String + + enum CodingKeys: String, CodingKey { + case clientSecret = "client_secret" + case sessionID = "sid" + case code = "token" + } + + func dictionary() throws -> [AnyHashable: Any] { + let jsonData = try JSONEncoder().encode(self) + let object = try JSONSerialization.jsonObject(with: jsonData) + guard let dictionary = object as? [AnyHashable: Any] else { + throw AuthenticationError.dictionaryError + } + + return dictionary + } +} diff --git a/RiotSwiftUI/Modules/Common/ErrorHandling/AlertInfo.swift b/RiotSwiftUI/Modules/Common/ErrorHandling/AlertInfo.swift index 509eb81d06..6c442be486 100644 --- a/RiotSwiftUI/Modules/Common/ErrorHandling/AlertInfo.swift +++ b/RiotSwiftUI/Modules/Common/ErrorHandling/AlertInfo.swift @@ -36,12 +36,20 @@ struct AlertInfo: Identifiable { var secondaryButton: (title: String, action: (() -> Void)?)? = nil } -extension AlertInfo where T == Int { +extension AlertInfo { + /// Initialises the type with the title and message from an `NSError` along with the default Ok button. + init?(error: NSError? = nil) where T == Int { + self.init(id: error?.code ?? -1, error: error) + } + /// Initialises the type with the title and message from an `NSError` along with the default Ok button. - init?(error: NSError? = nil) { + /// - Parameters: + /// - id: An ID that identifies the error. + /// - error: The Error that occurred. + init?(id: T, error: NSError? = nil) { guard error?.domain != NSURLErrorDomain && error?.code != NSURLErrorCancelled else { return nil } - id = error?.code ?? -1 + self.id = id title = error?.userInfo[NSLocalizedFailureReasonErrorKey] as? String ?? VectorL10n.error message = error?.userInfo[NSLocalizedDescriptionKey] as? String ?? VectorL10n.errorCommonMessage } From 9f3305d1c6805744cf2887ea1d9a7ecb02063d6f Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 26 Apr 2022 17:20:09 +0100 Subject: [PATCH 2/5] Update AuthenticationService following PR comments. --- Riot/Categories/MXHTTPClient+Async.swift | 44 ++++- Riot/Categories/MXRestClient+Async.swift | 151 +++++++++--------- .../AuthenticationCoordinatorState.swift | 23 ++- .../Common/AuthenticationModels.swift | 20 ++- .../Common/HomeserverAddress.swift | 6 +- .../MatrixSDK/AuthenticationPendingData.swift | 1 + .../MatrixSDK/AuthenticationService.swift | 88 +++++----- .../Service/MatrixSDK/LoginModels.swift | 13 ++ .../MatrixSDK/RegistrationWizard.swift | 53 +++--- .../Service/MatrixSDK/ThreePIDModels.swift | 39 +---- 10 files changed, 224 insertions(+), 214 deletions(-) diff --git a/Riot/Categories/MXHTTPClient+Async.swift b/Riot/Categories/MXHTTPClient+Async.swift index 93ac419f93..ab5c87a521 100644 --- a/Riot/Categories/MXHTTPClient+Async.swift +++ b/Riot/Categories/MXHTTPClient+Async.swift @@ -26,17 +26,51 @@ extension MXHTTPClient { case unknownError } + /// Validates a third party ID code at the given URL. + func validateThreePIDCode(submitURL: String, validationBody: ThreePIDValidationCodeBody) async throws -> Bool { + let data = try validationBody.jsonData() + let responseDictionary = try await request(withMethod: "POST", path: submitURL, parameters: nil, data: data) + + // Response is a json dictionary with a single success parameter + guard let success = responseDictionary["success"] as? Bool else { + throw ClientError.invalidResponse + } + + return success + } + /// An async version of `request(withMethod:path:parameters:success:failure:)`. - func request(withMethod method: String, path: String, parameters: [AnyHashable: Any]) async throws -> [AnyHashable: Any] { + func request(withMethod method: String, + path: String, + parameters: [AnyHashable: Any]?, + needsAuthentication: Bool? = nil, + data: Data? = nil, + headers: [AnyHashable: Any]? = nil, + timeout: TimeInterval = -1) async throws -> [AnyHashable: Any] { + try await getResponse { success, failure in + request(withMethod: method, + path: path, + parameters: parameters, + needsAuthentication: needsAuthentication ?? isAuthenticatedClient, + data: data, + headers: headers, + timeout: timeout, + uploadProgress: nil, + success: success, + failure: failure) + } + } + + private func getResponse(_ callback: (@escaping (T?) -> Void, @escaping (Error?) -> Void) -> MXHTTPOperation) async throws -> T { try await withCheckedThrowingContinuation { continuation in - request(withMethod: method, path: path, parameters: parameters) { jsonDictionary in - guard let jsonDictionary = jsonDictionary else { + _ = callback { response in + guard let response = response else { continuation.resume(with: .failure(ClientError.invalidResponse)) return } - continuation.resume(with: .success(jsonDictionary)) - } failure: { error in + continuation.resume(with: .success(response)) + } _: { error in continuation.resume(with: .failure(error ?? ClientError.unknownError)) } } diff --git a/Riot/Categories/MXRestClient+Async.swift b/Riot/Categories/MXRestClient+Async.swift index e58ff06af6..c99289002f 100644 --- a/Riot/Categories/MXRestClient+Async.swift +++ b/Riot/Categories/MXRestClient+Async.swift @@ -30,112 +30,113 @@ extension MXRestClient { /// An async version of `wellKnow(_:failure:)`. func wellKnown() async throws -> MXWellKnown { - try await withCheckedThrowingContinuation { continuation in - wellKnow { wellKnown in - guard let wellKnown = wellKnown else { - continuation.resume(with: .failure(ClientError.invalidResponse)) - return - } - - continuation.resume(with: .success(wellKnown)) - } failure: { error in - continuation.resume(with: .failure(error ?? ClientError.unknownError)) - } + try await getResponse { success, failure in + wellKnow(success, failure: failure) } } /// An async version of `getRegisterSession(completion:)`. func getRegisterSession() async throws -> MXAuthenticationSession { - try await withCheckedThrowingContinuation { continuation in - getRegisterSession { response in - guard let session = response.value else { - continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) - return - } - - continuation.resume(with: .success(session)) - } - } + try await getResponse(getRegisterSession) } /// An async version of `getLoginSession(completion:)`. func getLoginSession() async throws -> MXAuthenticationSession { - try await withCheckedThrowingContinuation { continuation in - getLoginSession { response in - guard let session = response.value else { - continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) - return - } - - continuation.resume(with: .success(session)) - } - } + try await getResponse(getLoginSession) } /// An async version of `isUsernameAvailable(_:completion:)`. func isUsernameAvailable(_ username: String) async throws -> Bool { - try await withCheckedThrowingContinuation { continuation in - isUsernameAvailable(username) { response in - guard let availability = response.value else { - continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) - return - } - - continuation.resume(with: .success(availability.available)) - } + let availability = try await getResponse { completion in + isUsernameAvailable(username, completion: completion) } + return availability.available + } + + /// An async version of `register(parameters:completion:)`, that takes a `RegistrationParameters` value instead of a dictionary. + func register(parameters: RegistrationParameters) async throws -> MXLoginResponse { + let dictionary = try parameters.dictionary() + return try await register(parameters: dictionary) } /// An async version of `register(parameters:completion:)`. func register(parameters: [String: Any]) async throws -> MXLoginResponse { + let jsonDictionary = try await getResponse { completion in + register(parameters: parameters, completion: completion) + } + guard let loginResponse = MXLoginResponse(fromJSON: jsonDictionary) else { throw ClientError.decodingError } + return loginResponse + } + + /// An async version of both `requestToken(forEmail:isDuringRegistration:clientSecret:sendAttempt:nextLink:success:failure:)` and + /// `requestToken(forPhoneNumber:isDuringRegistration:countryCode:clientSecret:sendAttempt:nextLink:success:failure:)` depending + /// on the kind of third party ID is supplied to the `threePID` parameter. + func requestTokenDuringRegistration(for threePID: RegisterThreePID, clientSecret: String, sendAttempt: UInt) async throws -> RegistrationThreePIDTokenResponse { + switch threePID { + case .email(let email): + let sessionID: String = try await getResponse { success, failure in + requestToken(forEmail: email, + isDuringRegistration: true, + clientSecret: clientSecret, + sendAttempt: sendAttempt, + nextLink: nil, + success: success, + failure: failure) + } + + return RegistrationThreePIDTokenResponse(sessionID: sessionID) + case .msisdn(let msisdn, let countryCode): + let (sessionID, msisdn, submitURL): (String?, String?, String?) = try await getResponse { success, failure in + requestToken(forPhoneNumber: msisdn, + isDuringRegistration: true, + countryCode: countryCode, + clientSecret: clientSecret, + sendAttempt: sendAttempt, + nextLink: nil, + success: success, + failure: failure) + } + guard let sessionID = sessionID else { throw ClientError.invalidResponse } + return RegistrationThreePIDTokenResponse(sessionID: sessionID, submitURL: submitURL, msisdn: msisdn) + } + } + + // MARK: Private + + private func getResponse(_ callback: (@escaping (MXResponse) -> Void) -> MXHTTPOperation) async throws -> T { try await withCheckedThrowingContinuation { continuation in - register(parameters: parameters) { response in - guard let jsonDictionary = response.value else { + _ = callback { response in + guard let value = response.value else { continuation.resume(with: .failure(response.error ?? ClientError.unknownError)) return } - guard let loginResponse = MXLoginResponse(fromJSON: jsonDictionary) else { - continuation.resume(with: .failure(ClientError.decodingError)) + continuation.resume(with: .success(value)) + } + } + } + + private func getResponse(_ callback: (@escaping (T?) -> Void, @escaping (Error?) -> Void) -> MXHTTPOperation) async throws -> T { + try await withCheckedThrowingContinuation { continuation in + _ = callback { response in + guard let response = response else { + continuation.resume(with: .failure(ClientError.invalidResponse)) return } - continuation.resume(with: .success(loginResponse)) + continuation.resume(with: .success(response)) + } _: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) } } } - /// An async version of both `requestToken(forEmail:isDuringRegistration:clientSecret:sendAttempt:nextLink:success:failure:)` and - /// `requestToken(forPhoneNumber:isDuringRegistration:countryCode:clientSecret:sendAttempt:nextLink:success:failure:)` depending - /// on the kind of third party ID is supplied to the `threePID` parameter. - func requestTokenDuringRegistration(for threePID: RegisterThreePID, clientSecret: String, sendAttempt: UInt) async throws -> RegistrationThreePIDTokenResponse { + private func getResponse(_ callback: (@escaping (T?, U?, V?) -> Void, @escaping (Error?) -> Void) -> MXHTTPOperation) async throws -> (T?, U?, V?) { try await withCheckedThrowingContinuation { continuation in - switch threePID { - case .email(let email): - requestToken(forEmail: email, isDuringRegistration: true, clientSecret: clientSecret, sendAttempt: sendAttempt, nextLink: nil) { sessionID in - guard let sessionID = sessionID else { - continuation.resume(with: .failure(ClientError.invalidResponse)) - return - } - - let response = RegistrationThreePIDTokenResponse(sessionID: sessionID) - continuation.resume(with: .success(response)) - } failure: { error in - continuation.resume(with: .failure(error ?? ClientError.unknownError)) - } - - case .msisdn(let msisdn, let countryCode): - requestToken(forPhoneNumber: msisdn, isDuringRegistration: true, countryCode: countryCode, clientSecret: clientSecret, sendAttempt: sendAttempt, nextLink: nil) { sessionID, msisdn, submitURL in - guard let sessionID = sessionID else { - continuation.resume(with: .failure(ClientError.invalidResponse)) - return - } - - let response = RegistrationThreePIDTokenResponse(sessionID: sessionID, submitURL: submitURL, msisdn: msisdn) - continuation.resume(with: .success(response)) - } failure: { error in - continuation.resume(with: .failure(error ?? ClientError.unknownError)) - } + _ = callback { arg1, arg2, arg3 in + continuation.resume(with: .success((arg1, arg2, arg3))) + } _: { error in + continuation.resume(with: .failure(error ?? ClientError.unknownError)) } } } diff --git a/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift b/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift index 9eebfda03f..c5bca72a12 100644 --- a/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift +++ b/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift @@ -24,18 +24,25 @@ struct AuthenticationCoordinatorState { // var signMode: SignMode = .unknown var resetPasswordEmail: String? - /// The homeserver address as returned by the server. - var homeserverAddress: String? - /// The homeserver address as input by the user (it can differ to the well-known request). - var homeserverAddressFromUser: String? + /// Information about the currently selected homeserver. + var selectedHomeserver: SelectedHomeserver /// For SSO session recovery var deviceId: String? - // MARK: Network result - var loginMode: LoginMode = .unknown - /// Supported types for the login. - var loginModeSupportedTypes = [MXLoginFlow]() var knownCustomHomeServersUrls = [String]() var isForceLoginFallbackEnabled = false + + struct SelectedHomeserver { + /// The homeserver address as returned by the server. + var address: String + /// The homeserver address as input by the user (it can differ to the well-known request). + var addressFromUser: String? + + /// The preferred login mode for the server + var preferredLoginMode: LoginMode = .unknown + /// Supported types for the login. + var loginModeSupportedTypes = [MXLoginFlow]() + } + } diff --git a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift index 5435051c24..fa3bbd1fd7 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift @@ -16,17 +16,24 @@ import Foundation -/// Errors that can be thrown from `AuthenticationService`, `RegistrationWizard` and `LoginWizard`. +/// A value that dictates the authentication flow that will be used. +enum AuthenticationMode { + case login + case registration +} + +/// Errors that can be thrown from `AuthenticationService`. enum AuthenticationError: String, Error { - // MARK: AuthenticationService /// A failure to convert a struct into a dictionary. case dictionaryError case invalidHomeserver case loginFlowNotCalled case missingRegistrationWizard case missingMXRestClient - - // MARK: RegistrationWizard +} + +/// Errors that can be thrown from `RegistrationWizard` +enum RegistrationError: String, Error { case createAccountNotCalled case noPendingThreePID case missingThreePIDURL @@ -34,6 +41,11 @@ enum AuthenticationError: String, Error { case threePIDClientFailure } +/// Errors that can be thrown from `LoginWizard` +enum LoginError: String, Error { + case unimplemented +} + /// Represents an SSO Identity Provider as provided in a login flow. struct SSOIdentityProvider: Identifiable { /// The identifier field (id field in JSON) is the Identity Provider identifier used for the SSO Web page redirection `/login/sso/redirect/{idp_id}`. diff --git a/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift b/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift index c65de672a3..d0bca43e08 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/HomeserverAddress.swift @@ -16,13 +16,15 @@ import Foundation -class HomeserverAddress { +struct HomeserverAddress { /// Ensures the address contains a scheme, otherwise makes it `https`. - static func sanitize(_ address: String) -> String { + static func sanitized(_ address: String) -> String { !address.contains("://") ? "https://\(address.lowercased())" : address.lowercased() } /// Strips the `https://` away from the address (but leaves `http://`) for display in labels. + /// + /// `http://` is left in the string to make it clear when a chosen server doesn't use SSL. static func displayable(_ address: String) -> String { address.replacingOccurrences(of: "https://", with: "") } diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift index 2aa4bc4992..ec6a0d8e5b 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift @@ -33,6 +33,7 @@ class AuthenticationPendingData { var currentSession: String? var isRegistrationStarted = false + var currentRegistrationResult: RegistrationResult? var currentThreePIDData: ThreePIDData? // MARK: - Setup diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift index 53692461f3..9bd61629f6 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift @@ -46,13 +46,13 @@ class AuthenticationService: NSObject { /// The address of the homeserver that the service is using. var homeserverAddress: String { - state.homeserverAddress ?? RiotSettings.shared.homeserverUrlString + state.selectedHomeserver.address } // MARK: Android OnboardingViewModel /// The current state of the authentication flow. - private var state = AuthenticationCoordinatorState() + private var state: AuthenticationCoordinatorState /// The currently executing async task. private var currentTask: Task? { willSet { @@ -64,62 +64,53 @@ class AuthenticationService: NSObject { // MARK: - Setup override init() { - guard let homeserverURL = URL(string: RiotSettings.shared.homeserverUrlString) else { - fatalError("[AuthenticationService]: Failed to create URL from default homeserver URL string.") + if let homeserverURL = URL(string: RiotSettings.shared.homeserverUrlString) { + // Use the same homeserver that was last used. + state = AuthenticationCoordinatorState(selectedHomeserver: .init(address: RiotSettings.shared.homeserverUrlString)) + client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + } else if let homeserverURL = URL(string: BuildSettings.serverConfigDefaultHomeserverUrlString) { + // Fall back to the default homeserver if the stored one is invalid. + state = AuthenticationCoordinatorState(selectedHomeserver: .init(address: BuildSettings.serverConfigDefaultHomeserverUrlString)) + client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + } else { + MXLog.failure("[AuthenticationService]: Failed to create URL from default homeserver URL string.") + fatalError("Invalid default homeserver URL string.") } - client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) - super.init() } // MARK: - Android OnboardingViewModel - func loginFlow(homeserverAddress: String) async { + func startFlow(for homeserverAddress: String, as authenticationMode: AuthenticationMode) async throws { currentTask = Task { cancelPendingLoginOrRegistration() - do { - let data = try await loginFlow(for: homeserverAddress) - - guard !Task.isCancelled else { return } - - // Valid Homeserver, add it to the history. - // Note: we add what the user has input, as the data can contain a different value. - RiotSettings.shared.homeserverUrlString = homeserverAddress - - let loginMode: LoginMode - - if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }), - data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { - loginMode = .ssoAndPassword(ssoIdentityProviders: data.ssoIdentityProviders) - } else if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }) { - loginMode = .sso(ssoIdentityProviders: data.ssoIdentityProviders) - } else if data.supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { - loginMode = .password - } else { - loginMode = .unsupported - } - - state.homeserverAddressFromUser = homeserverAddress - state.homeserverAddress = data.homeserverAddress - state.loginMode = loginMode - state.loginModeSupportedTypes = data.supportedLoginTypes - } catch { - #warning("Show an error message and/or handle the error?") - return + let loginFlows = try await loginFlow(for: homeserverAddress) + + var registrationFlow: RegistrationResult? + if authenticationMode == .registration { + let wizard = try registrationWizard() + registrationFlow = try await wizard.registrationFlow() } + + guard !Task.isCancelled else { return } + + // Valid Homeserver, add it to the history. + // Note: we add what the user has input, as the data can contain a different value. + RiotSettings.shared.homeserverUrlString = homeserverAddress + + state.selectedHomeserver = .init(address: loginFlows.homeserverAddress, + addressFromUser: homeserverAddress, + preferredLoginMode: loginFlows.loginMode, + loginModeSupportedTypes: loginFlows.supportedLoginTypes) + + pendingData?.currentRegistrationResult = registrationFlow } - } - - func refreshServer(homeserverAddress: String) async throws -> (LoginFlowResult, RegistrationResult) { - let loginFlows = try await loginFlow(for: homeserverAddress) - let wizard = try registrationWizard() - let registrationFlow = try await wizard.registrationFlow() - state.homeserverAddress = homeserverAddress - - return (loginFlows, registrationFlow) + try await currentTask?.value } // MARK: - Public @@ -148,18 +139,13 @@ class AuthenticationService: NSObject { MXKAccountManager.shared().activeAccounts?.first?.mxSession } - enum AuthenticationMode { - case login - case registration - } - /// Request the supported login flows for this homeserver. /// This is the first method to call to be able to get a wizard to login or to create an account /// - Parameter homeserverAddress: The homeserver string entered by the user. func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { pendingData = nil - let homeserverAddress = HomeserverAddress.sanitize(homeserverAddress) + let homeserverAddress = HomeserverAddress.sanitized(homeserverAddress) guard var homeserverURL = URL(string: homeserverAddress) else { throw AuthenticationError.invalidHomeserver diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift index 22108de730..7e1e4b9001 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginModels.swift @@ -20,6 +20,19 @@ struct LoginFlowResult { let supportedLoginTypes: [MXLoginFlow] let ssoIdentityProviders: [SSOIdentityProvider] let homeserverAddress: String + + var loginMode: LoginMode { + if supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }), + supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { + return .ssoAndPassword(ssoIdentityProviders: ssoIdentityProviders) + } else if supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypeSSO }) { + return .sso(ssoIdentityProviders: ssoIdentityProviders) + } else if supportedLoginTypes.contains(where: { $0.type == kMXLoginFlowTypePassword }) { + return .password + } else { + return .unsupported + } + } } enum LoginMode { diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift index dc1ff5391a..bf7af84203 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift @@ -93,7 +93,7 @@ class RegistrationWizard { /// - Parameter response: The response from ReCaptcha func performReCaptcha(response: String) async throws -> RegistrationResult { guard let session = pendingData.currentSession else { - throw AuthenticationError.createAccountNotCalled + throw RegistrationError.createAccountNotCalled } let parameters = RegistrationParameters(auth: AuthenticationParameters.captchaParameters(session: session, captchaResponse: response)) @@ -103,7 +103,7 @@ class RegistrationWizard { /// Perform the "m.login.terms" stage. func acceptTerms() async throws -> RegistrationResult { guard let session = pendingData.currentSession else { - throw AuthenticationError.createAccountNotCalled + throw RegistrationError.createAccountNotCalled } let parameters = RegistrationParameters(auth: AuthenticationParameters(type: kMXLoginFlowTypeTerms, session: session)) @@ -113,7 +113,7 @@ class RegistrationWizard { /// Perform the "m.login.dummy" stage. func dummy() async throws -> RegistrationResult { guard let session = pendingData.currentSession else { - throw AuthenticationError.createAccountNotCalled + throw RegistrationError.createAccountNotCalled } let parameters = RegistrationParameters(auth: AuthenticationParameters(type: kMXLoginFlowTypeDummy, session: session)) @@ -132,7 +132,7 @@ class RegistrationWizard { /// Ask the homeserver to send again the current threePID (email or msisdn). func sendAgainThreePID() async throws -> RegistrationResult { guard let threePID = pendingData.currentThreePIDData?.threePID else { - throw AuthenticationError.createAccountNotCalled + throw RegistrationError.createAccountNotCalled } return try await sendThreePID(threePID: threePID) } @@ -147,51 +147,46 @@ class RegistrationWizard { /// Once the email is validated, this method will return successfully. /// - Parameter delay How long to wait before sending the request. func checkIfEmailHasBeenValidated(delay: TimeInterval) async throws -> RegistrationResult { + MXLog.failure("The delay on this method is no longer available. Move this to the object handling the polling.") guard let parameters = pendingData.currentThreePIDData?.registrationParameters else { - throw AuthenticationError.noPendingThreePID + throw RegistrationError.noPendingThreePID } - return try await performRegistrationRequest(parameters: parameters, delay: delay) + return try await performRegistrationRequest(parameters: parameters) } // MARK: - Private private func validateThreePid(code: String) async throws -> RegistrationResult { guard let threePIDData = pendingData.currentThreePIDData else { - throw AuthenticationError.noPendingThreePID + throw RegistrationError.noPendingThreePID } - guard let url = threePIDData.registrationResponse.submitURL else { - throw AuthenticationError.missingThreePIDURL + guard let submitURL = threePIDData.registrationResponse.submitURL else { + throw RegistrationError.missingThreePIDURL } let validationBody = ThreePIDValidationCodeBody(clientSecret: pendingData.clientSecret, sessionID: threePIDData.registrationResponse.sessionID, code: code) - let validationDictionary = try validationBody.dictionary() #warning("Seems odd to pass a nil baseURL and then the url as the path, yet this is how MXK3PID works") guard let httpClient = MXHTTPClient(baseURL: nil, andOnUnrecognizedCertificateBlock: nil) else { - throw AuthenticationError.threePIDClientFailure + throw RegistrationError.threePIDClientFailure } - let responseDictionary = try await httpClient.request(withMethod: "POST", path: url, parameters: validationDictionary) - - // Response is a json dictionary with a single success parameter - if responseDictionary["success"] as? Bool == true { - // The entered code is correct - // Same than validate email - let parameters = threePIDData.registrationParameters - return try await performRegistrationRequest(parameters: parameters, delay: 3) - } else { - // The code is not correct - throw AuthenticationError.threePIDValidationFailure + guard try await httpClient.validateThreePIDCode(submitURL: submitURL, validationBody: validationBody) else { + throw RegistrationError.threePIDValidationFailure } + + let parameters = threePIDData.registrationParameters + MXLog.failure("This method used to add a 3-second delay to the request. This should be moved to the caller of `handleValidateThreePID`.") + return try await performRegistrationRequest(parameters: parameters) } private func sendThreePID(threePID: RegisterThreePID) async throws -> RegistrationResult { guard let session = pendingData.currentSession else { - throw AuthenticationError.createAccountNotCalled + throw RegistrationError.createAccountNotCalled } let response = try await client.requestTokenDuringRegistration(for: threePID, @@ -217,17 +212,9 @@ class RegistrationWizard { return try await performRegistrationRequest(parameters: parameters) } - private func performRegistrationRequest(parameters: RegistrationParameters, - delay: TimeInterval = 0) async throws -> RegistrationResult { - try await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) - - let jsonData = try JSONEncoder().encode(parameters) - guard let dictionary = try JSONSerialization.jsonObject(with: jsonData) as? [String: Any] else { - throw MXRestClient.ClientError.decodingError - } - + private func performRegistrationRequest(parameters: RegistrationParameters) async throws -> RegistrationResult { do { - let response = try await client.register(parameters: dictionary) + let response = try await client.register(parameters: parameters) let credentials = MXCredentials(loginResponse: response, andDefaultCredentials: client.credentials) return .success(sessionCreator.createSession(credentials: credentials, client: client)) } catch { diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift index 6c8047c5de..57682a88f2 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/ThreePIDModels.swift @@ -36,36 +36,9 @@ struct ThreePIDCredentials: Codable { } struct ThreePIDData { - let email: String - let msisdn: String - let country: String + let threePID: RegisterThreePID let registrationResponse: RegistrationThreePIDTokenResponse let registrationParameters: RegistrationParameters - - var threePID: RegisterThreePID { - email.isEmpty ? .msisdn(msisdn: msisdn, countryCode: country) : .email(email) - } -} - -extension ThreePIDData { - init(threePID: RegisterThreePID, - registrationResponse: RegistrationThreePIDTokenResponse, - registrationParameters: RegistrationParameters) { - switch threePID { - case .email(let email): - self.init(email: email, - msisdn: "", - country: "", - registrationResponse: registrationResponse, - registrationParameters: registrationParameters) - case .msisdn(let msisdn, let countryCode): - self.init(email: "", - msisdn: msisdn, - country: countryCode, - registrationResponse: registrationResponse, - registrationParameters: registrationParameters) - } - } } // TODO: This could potentially become an MXJSONModel? @@ -110,13 +83,7 @@ struct ThreePIDValidationCodeBody: Codable { case code = "token" } - func dictionary() throws -> [AnyHashable: Any] { - let jsonData = try JSONEncoder().encode(self) - let object = try JSONSerialization.jsonObject(with: jsonData) - guard let dictionary = object as? [AnyHashable: Any] else { - throw AuthenticationError.dictionaryError - } - - return dictionary + func jsonData() throws -> Data { + try JSONEncoder().encode(self) } } From 6b56d3b72df8574ea88c71cdfbc815841c60c0ab Mon Sep 17 00:00:00 2001 From: Doug Date: Wed, 27 Apr 2022 11:08:52 +0100 Subject: [PATCH 3/5] Simplify Authentication state with individual structs. Add tests for AuthenticationService. --- .../Common/AuthenticationModels.swift | 2 +- .../MatrixSDK/AuthenticationPendingData.swift | 44 ----- .../MatrixSDK/AuthenticationService.swift | 155 +++++------------- .../MatrixSDK/AuthenticationState.swift | 22 +-- .../Service/MatrixSDK/LoginWizard.swift | 10 ++ .../MatrixSDK/RegistrationWizard.swift | 59 ++++--- RiotTests/AuthenticationServiceTests.swift | 83 ++++++++++ 7 files changed, 181 insertions(+), 194 deletions(-) delete mode 100644 RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift rename Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift => RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift (72%) create mode 100644 RiotTests/AuthenticationServiceTests.swift diff --git a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift index fa3bbd1fd7..1900f9f245 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift @@ -35,7 +35,7 @@ enum AuthenticationError: String, Error { /// Errors that can be thrown from `RegistrationWizard` enum RegistrationError: String, Error { case createAccountNotCalled - case noPendingThreePID + case missingThreePIDData case missingThreePIDURL case threePIDValidationFailure case threePIDClientFailure diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift deleted file mode 100644 index ec6a0d8e5b..0000000000 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationPendingData.swift +++ /dev/null @@ -1,44 +0,0 @@ -// -// Copyright 2022 New Vector Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation - -/// This class holds all pending data when creating a session, either by login or by register -class AuthenticationPendingData { - let homeserverAddress: String - - // MARK: - Common - - var clientSecret = UUID().uuidString - var sendAttempt: UInt = 0 - - // MARK: - For login - - // var resetPasswordData: ResetPasswordData? - - // MARK: - For registration - - var currentSession: String? - var isRegistrationStarted = false - var currentRegistrationResult: RegistrationResult? - var currentThreePIDData: ThreePIDData? - - // MARK: - Setup - - init(homeserverAddress: String) { - self.homeserverAddress = homeserverAddress - } -} diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift index 9bd61629f6..15552b3217 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift @@ -33,45 +33,29 @@ class AuthenticationService: NSObject { /// The rest client used to make authentication requests. private var client: MXRestClient - /// Pending data collected as the authentication flow progresses. - private var pendingData: AuthenticationPendingData? - /// The current registration wizard or `nil` if `registrationWizard()` hasn't been called. - private var currentRegistrationWizard: RegistrationWizard? - /// The current login wizard or `nil` if `loginWizard()` hasn't been called. - private var currentLoginWizard: LoginWizard? /// The object used to create a new `MXSession` when authentication has completed. private var sessionCreator = SessionCreator() // MARK: Public - /// The address of the homeserver that the service is using. - var homeserverAddress: String { - state.selectedHomeserver.address - } - - - // MARK: Android OnboardingViewModel /// The current state of the authentication flow. - private var state: AuthenticationCoordinatorState - /// The currently executing async task. - private var currentTask: Task? { - willSet { - currentTask?.cancel() - } - } - + private(set) var state: AuthenticationState + /// The current login wizard or `nil` if `startFlow` hasn't been called. + private(set) var loginWizard: LoginWizard? + /// The current registration wizard or `nil` if `startFlow` hasn't been called for `.registration`. + private(set) var registrationWizard: RegistrationWizard? // MARK: - Setup override init() { if let homeserverURL = URL(string: RiotSettings.shared.homeserverUrlString) { // Use the same homeserver that was last used. - state = AuthenticationCoordinatorState(selectedHomeserver: .init(address: RiotSettings.shared.homeserverUrlString)) + state = AuthenticationState(authenticationMode: .login, homeserverAddress: RiotSettings.shared.homeserverUrlString) client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) } else if let homeserverURL = URL(string: BuildSettings.serverConfigDefaultHomeserverUrlString) { // Fall back to the default homeserver if the stored one is invalid. - state = AuthenticationCoordinatorState(selectedHomeserver: .init(address: BuildSettings.serverConfigDefaultHomeserverUrlString)) + state = AuthenticationState(authenticationMode: .login, homeserverAddress: BuildSettings.serverConfigDefaultHomeserverUrlString) client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) } else { @@ -85,32 +69,27 @@ class AuthenticationService: NSObject { // MARK: - Android OnboardingViewModel func startFlow(for homeserverAddress: String, as authenticationMode: AuthenticationMode) async throws { - currentTask = Task { - cancelPendingLoginOrRegistration() - - let loginFlows = try await loginFlow(for: homeserverAddress) - - var registrationFlow: RegistrationResult? - if authenticationMode == .registration { - let wizard = try registrationWizard() - registrationFlow = try await wizard.registrationFlow() - } - - guard !Task.isCancelled else { return } - - // Valid Homeserver, add it to the history. - // Note: we add what the user has input, as the data can contain a different value. - RiotSettings.shared.homeserverUrlString = homeserverAddress - - state.selectedHomeserver = .init(address: loginFlows.homeserverAddress, - addressFromUser: homeserverAddress, - preferredLoginMode: loginFlows.loginMode, - loginModeSupportedTypes: loginFlows.supportedLoginTypes) - - pendingData?.currentRegistrationResult = registrationFlow - } + reset() + + let loginFlows = try await loginFlow(for: homeserverAddress) + + // Valid Homeserver, add it to the history. + // Note: we add what the user has input, as the data can contain a different value. + RiotSettings.shared.homeserverUrlString = homeserverAddress - try await currentTask?.value + state.homeserver = .init(address: loginFlows.homeserverAddress, + addressFromUser: homeserverAddress, + preferredLoginMode: loginFlows.loginMode, + loginModeSupportedTypes: loginFlows.supportedLoginTypes) + + let loginWizard = LoginWizard() + self.loginWizard = loginWizard + + if authenticationMode == .registration { + let registrationWizard = RegistrationWizard(client: client) + state.initialRegistrationFlow = try await registrationWizard.registrationFlow() + self.registrationWizard = registrationWizard + } } // MARK: - Public @@ -142,16 +121,14 @@ class AuthenticationService: NSObject { /// Request the supported login flows for this homeserver. /// This is the first method to call to be able to get a wizard to login or to create an account /// - Parameter homeserverAddress: The homeserver string entered by the user. - func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { - pendingData = nil - + private func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { let homeserverAddress = HomeserverAddress.sanitized(homeserverAddress) guard var homeserverURL = URL(string: homeserverAddress) else { throw AuthenticationError.invalidHomeserver } - let pendingData = AuthenticationPendingData(homeserverAddress: homeserverAddress) + let state = AuthenticationState(authenticationMode: .login, homeserverAddress: homeserverAddress) if let wellKnown = try? await wellKnown(for: homeserverURL), let baseURL = URL(string: wellKnown.homeServer.baseUrl) { @@ -164,7 +141,7 @@ class AuthenticationService: NSObject { let loginFlow = try await getLoginFlowResult(client: client) self.client = client - self.pendingData = pendingData + self.state = state return loginFlow } @@ -172,16 +149,14 @@ class AuthenticationService: NSObject { /// Request the supported login flows for the corresponding session. /// This method is used to get the flows for a server after a soft-logout. /// - Parameter session: The MXSession where a soft-logout has occurred. - func loginFlow(for session: MXSession) async throws -> LoginFlowResult { - pendingData = nil - + private func loginFlow(for session: MXSession) async throws -> LoginFlowResult { guard let client = session.matrixRestClient else { throw AuthenticationError.missingMXRestClient } - let pendingData = AuthenticationPendingData(homeserverAddress: client.homeserver) + let state = AuthenticationState(authenticationMode: .login, homeserverAddress: client.homeserver) let loginFlow = try await getLoginFlowResult(client: session.matrixRestClient) self.client = client - self.pendingData = pendingData + self.state = state return loginFlow } @@ -201,66 +176,18 @@ class AuthenticationService: NSObject { } } - /// Return a LoginWizard, to login to the homeserver. The login flow has to be retrieved first. - /// - /// See ``LoginWizard`` for more details - func loginWizard() throws -> LoginWizard { - if let currentLoginWizard = currentLoginWizard { - return currentLoginWizard - } - - guard let pendingData = pendingData else { - throw AuthenticationError.loginFlowNotCalled - } - - let wizard = LoginWizard() - return wizard - } - - /// Return a RegistrationWizard, to create a matrix account on the homeserver. The login flow has to be retrieved first. - /// - /// See ``RegistrationWizard`` for more details. - func registrationWizard() throws -> RegistrationWizard { - if let currentRegistrationWizard = currentRegistrationWizard { - return currentRegistrationWizard - } - - guard let pendingData = pendingData else { - throw AuthenticationError.loginFlowNotCalled - } - - - let wizard = RegistrationWizard(client: client, pendingData: pendingData) - currentRegistrationWizard = wizard - return wizard - } - /// True when login and password has been sent with success to the homeserver var isRegistrationStarted: Bool { - currentRegistrationWizard?.isRegistrationStarted ?? false - } - - /// Cancel pending login or pending registration - func cancelPendingLoginOrRegistration() { - currentTask?.cancel() - - currentLoginWizard = nil - currentRegistrationWizard = nil - - // Keep only the homesever config - guard let pendingData = pendingData else { - // Should not happen - return - } - - self.pendingData = AuthenticationPendingData(homeserverAddress: pendingData.homeserverAddress) + registrationWizard?.isRegistrationStarted ?? false } - /// Reset all pending settings, including current HomeServerConnectionConfig + /// Reset the service to a fresh state. func reset() { - pendingData = nil - currentRegistrationWizard = nil - currentLoginWizard = nil + loginWizard = nil + registrationWizard = nil + + // The previously used homeserver is re-used as `startFlow` will be called again a replace it anyway. + self.state = AuthenticationState(authenticationMode: .login, homeserverAddress: state.homeserver.address) } /// Create a session after a SSO successful login @@ -291,7 +218,7 @@ class AuthenticationService: NSObject { // MARK: - Private - private func getLoginFlowResult(client: MXRestClient/*, versions: Versions*/) async throws -> LoginFlowResult { + private func getLoginFlowResult(client: MXRestClient) async throws -> LoginFlowResult { // Get the login flow let loginFlowResponse = try await client.getLoginSession() diff --git a/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift similarity index 72% rename from Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift rename to RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift index c5bca72a12..186e3055b3 100644 --- a/Riot/Modules/Onboarding/AuthenticationCoordinatorState.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift @@ -18,22 +18,23 @@ import Foundation import MatrixSDK @available(iOS 14.0, *) -struct AuthenticationCoordinatorState { - // MARK: User choices +struct AuthenticationState { // var serverType: ServerType = .unknown - // var signMode: SignMode = .unknown - var resetPasswordEmail: String? + var authenticationMode: AuthenticationMode /// Information about the currently selected homeserver. - var selectedHomeserver: SelectedHomeserver + var homeserver: Homeserver + var isForceLoginFallbackEnabled = false - /// For SSO session recovery - var deviceId: String? + /// The registration flow response returned when calling `startFlow` for `.registration`. + var initialRegistrationFlow: RegistrationResult? - var knownCustomHomeServersUrls = [String]() - var isForceLoginFallbackEnabled = false + init(authenticationMode: AuthenticationMode, homeserverAddress: String) { + self.authenticationMode = authenticationMode + self.homeserver = Homeserver(address: homeserverAddress) + } - struct SelectedHomeserver { + struct Homeserver { /// The homeserver address as returned by the server. var address: String /// The homeserver address as input by the user (it can differ to the well-known request). @@ -44,5 +45,4 @@ struct AuthenticationCoordinatorState { /// Supported types for the login. var loginModeSupportedTypes = [MXLoginFlow]() } - } diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift index 40fe8098aa..48bb7e99fc 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/LoginWizard.swift @@ -17,5 +17,15 @@ import Foundation class LoginWizard { + struct State { + /// For SSO session recovery + var deviceId: String? + var resetPasswordEmail: String? + // var resetPasswordData: ResetPasswordData? + + var clientSecret = UUID().uuidString + var sendAttempt: UInt = 0 + } + // TODO } diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift index bf7af84203..c3e714ecaf 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/RegistrationWizard.swift @@ -27,33 +27,44 @@ import Foundation /// More documentation can be found in the file https://github.com/vector-im/element-android/blob/main/docs/signup.md /// and https://matrix.org/docs/spec/client_server/latest#account-registration-and-management class RegistrationWizard { + struct State { + var currentSession: String? + var isRegistrationStarted = false + var currentThreePIDData: ThreePIDData? + + var clientSecret = UUID().uuidString + var sendAttempt: UInt = 0 + } + let client: MXRestClient let sessionCreator: SessionCreator - let pendingData: AuthenticationPendingData + + private(set) var state: State /// This is the current ThreePID, waiting for validation. The SDK will store it in database, so it can be /// restored even if the app has been killed during the registration var currentThreePID: String? { - guard let threePid = pendingData.currentThreePIDData?.threePID else { return nil } + guard let threePid = state.currentThreePIDData?.threePID else { return nil } switch threePid { case .email(let string): return string case .msisdn(let msisdn, _): - return pendingData.currentThreePIDData?.registrationResponse.formattedMSISDN ?? msisdn + return state.currentThreePIDData?.registrationResponse.formattedMSISDN ?? msisdn } } /// True when login and password have been sent with success to the homeserver, /// i.e. `createAccount` has been called successfully. var isRegistrationStarted: Bool { - pendingData.isRegistrationStarted + state.isRegistrationStarted } - init(client: MXRestClient, sessionCreator: SessionCreator = SessionCreator(), pendingData: AuthenticationPendingData) { + init(client: MXRestClient, sessionCreator: SessionCreator = SessionCreator()) { self.client = client self.sessionCreator = sessionCreator - self.pendingData = pendingData + + self.state = State() } /// Call this method to get the possible registration flow of the current homeserver. @@ -84,7 +95,7 @@ class RegistrationWizard { initialDeviceDisplayName: String?) async throws -> RegistrationResult { let parameters = RegistrationParameters(username: username, password: password, initialDeviceDisplayName: initialDeviceDisplayName) let result = try await performRegistrationRequest(parameters: parameters) - pendingData.isRegistrationStarted = true + state.isRegistrationStarted = true return result } @@ -92,7 +103,7 @@ class RegistrationWizard { /// /// - Parameter response: The response from ReCaptcha func performReCaptcha(response: String) async throws -> RegistrationResult { - guard let session = pendingData.currentSession else { + guard let session = state.currentSession else { throw RegistrationError.createAccountNotCalled } @@ -102,7 +113,7 @@ class RegistrationWizard { /// Perform the "m.login.terms" stage. func acceptTerms() async throws -> RegistrationResult { - guard let session = pendingData.currentSession else { + guard let session = state.currentSession else { throw RegistrationError.createAccountNotCalled } @@ -112,7 +123,7 @@ class RegistrationWizard { /// Perform the "m.login.dummy" stage. func dummy() async throws -> RegistrationResult { - guard let session = pendingData.currentSession else { + guard let session = state.currentSession else { throw RegistrationError.createAccountNotCalled } @@ -125,13 +136,13 @@ class RegistrationWizard { /// - Parameter threePID the threePID to add to the account. If this is an email, the homeserver will send an email /// to validate it. For a msisdn a SMS will be sent. func addThreePID(threePID: RegisterThreePID) async throws -> RegistrationResult { - pendingData.currentThreePIDData = nil + state.currentThreePIDData = nil return try await sendThreePID(threePID: threePID) } /// Ask the homeserver to send again the current threePID (email or msisdn). func sendAgainThreePID() async throws -> RegistrationResult { - guard let threePID = pendingData.currentThreePIDData?.threePID else { + guard let threePID = state.currentThreePIDData?.threePID else { throw RegistrationError.createAccountNotCalled } return try await sendThreePID(threePID: threePID) @@ -148,8 +159,8 @@ class RegistrationWizard { /// - Parameter delay How long to wait before sending the request. func checkIfEmailHasBeenValidated(delay: TimeInterval) async throws -> RegistrationResult { MXLog.failure("The delay on this method is no longer available. Move this to the object handling the polling.") - guard let parameters = pendingData.currentThreePIDData?.registrationParameters else { - throw RegistrationError.noPendingThreePID + guard let parameters = state.currentThreePIDData?.registrationParameters else { + throw RegistrationError.missingThreePIDData } return try await performRegistrationRequest(parameters: parameters) @@ -158,8 +169,8 @@ class RegistrationWizard { // MARK: - Private private func validateThreePid(code: String) async throws -> RegistrationResult { - guard let threePIDData = pendingData.currentThreePIDData else { - throw RegistrationError.noPendingThreePID + guard let threePIDData = state.currentThreePIDData else { + throw RegistrationError.missingThreePIDData } guard let submitURL = threePIDData.registrationResponse.submitURL else { @@ -167,7 +178,7 @@ class RegistrationWizard { } - let validationBody = ThreePIDValidationCodeBody(clientSecret: pendingData.clientSecret, + let validationBody = ThreePIDValidationCodeBody(clientSecret: state.clientSecret, sessionID: threePIDData.registrationResponse.sessionID, code: code) @@ -185,17 +196,17 @@ class RegistrationWizard { } private func sendThreePID(threePID: RegisterThreePID) async throws -> RegistrationResult { - guard let session = pendingData.currentSession else { + guard let session = state.currentSession else { throw RegistrationError.createAccountNotCalled } let response = try await client.requestTokenDuringRegistration(for: threePID, - clientSecret: pendingData.clientSecret, - sendAttempt: pendingData.sendAttempt) + clientSecret: state.clientSecret, + sendAttempt: state.sendAttempt) - pendingData.sendAttempt += 1 + state.sendAttempt += 1 - let threePIDCredentials = ThreePIDCredentials(clientSecret: pendingData.clientSecret, sessionID: response.sessionID) + let threePIDCredentials = ThreePIDCredentials(clientSecret: state.clientSecret, sessionID: response.sessionID) let authenticationParameters: AuthenticationParameters switch threePID { case .email: @@ -206,7 +217,7 @@ class RegistrationWizard { let parameters = RegistrationParameters(auth: authenticationParameters) - pendingData.currentThreePIDData = ThreePIDData(threePID: threePID, registrationResponse: response, registrationParameters: parameters) + state.currentThreePIDData = ThreePIDData(threePID: threePID, registrationResponse: response, registrationParameters: parameters) // Send the session id for the first time return try await performRegistrationRequest(parameters: parameters) @@ -225,7 +236,7 @@ class RegistrationWizard { let authenticationSession = MXAuthenticationSession(fromJSON: jsonResponse) else { throw error } - pendingData.currentSession = authenticationSession.session + state.currentSession = authenticationSession.session return .flowResponse(authenticationSession.flowResult) } } diff --git a/RiotTests/AuthenticationServiceTests.swift b/RiotTests/AuthenticationServiceTests.swift new file mode 100644 index 0000000000..5496c09b2a --- /dev/null +++ b/RiotTests/AuthenticationServiceTests.swift @@ -0,0 +1,83 @@ +// +// Copyright 2022 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest + +@testable import Riot + +@available(iOS 14.0, *) +class AuthenticationServiceTests: XCTestCase { + func testRegistrationWizardWhenStartingLoginFlow() async throws { + // Given a fresh service. + let service = AuthenticationService() + XCTAssertNil(service.registrationWizard, "A new service shouldn't have a registration wizard.") + + // When starting a new login flow. + try await service.startFlow(for: "https://matrix.org", as: .login) + + // Then a registration wizard shouldn't have been created. + XCTAssertNil(service.registrationWizard, "The registration wizard should not exist if startFlow was called for login.") + } + + func testRegistrationWizard() async throws { + // Given a fresh service. + let service = AuthenticationService() + XCTAssertNil(service.registrationWizard, "A new service shouldn't provide a registration wizard.") + XCTAssertNil(service.state.initialRegistrationFlow, "A new service shouldn't provide an initial registration flow.") + + // When starting a new registration flow. + try await service.startFlow(for: "https://matrix.org", as: .registration) + + // Then a registration wizard should be available for use. + XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") + XCTAssertNotNil(service.state.initialRegistrationFlow, "The result from setting up a registration wizard should be available in the service.") + } + + func testReset() async throws { + // Given a service that has begun registration. + let service = AuthenticationService() + try await service.startFlow(for: "https://matrix.org", as: .registration) + _ = try await service.registrationWizard?.createAccount(username: UUID().uuidString, password: UUID().uuidString, initialDeviceDisplayName: "Test") + XCTAssertNotNil(service.loginWizard, "The login wizard should exist after startFlow has been called.") + XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") + XCTAssertNotNil(service.state.initialRegistrationFlow, "An initial registration flow should exist after starting a registration flow.") + XCTAssertTrue(service.isRegistrationStarted, "The service should show as having started registration.") + + // When resetting the service. + service.reset() + + // Then the wizards should no longer exist. + XCTAssertNil(service.loginWizard, "The login wizard should be cleared after calling reset.") + XCTAssertNil(service.registrationWizard, "The registration wizard should be cleared after calling reset.") + XCTAssertNil(service.state.initialRegistrationFlow, "The initial registration flow should be cleared when calling reset.") + XCTAssertFalse(service.isRegistrationStarted, "The service should not indicate it has started registration after calling reset.") + } + + func testHomeserverState() async throws { + // Given a service that has begun login for one homeserver. + let service = AuthenticationService() + try await service.startFlow(for: "https://glasgow.social", as: .login) + XCTAssertEqual(service.state.homeserver.addressFromUser, "https://glasgow.social", "The initial address entered by the user should be stored.") + XCTAssertEqual(service.state.homeserver.address, "https://matrix.glasgow.social", "The initial address discovered from the well-known should be stored.") + + // When switching to a different homeserver + try await service.startFlow(for: "https://matrix.org", as: .login) + + // The the homeserver state should update to represent the new server + XCTAssertEqual(service.state.homeserver.addressFromUser, "https://matrix.org", "The new address entered by the user should be stored.") + XCTAssertEqual(service.state.homeserver.address, "https://matrix-client.matrix.org", "The new address discovered from the well-known should be stored.") + } +} From 60c5e2430bc7d0547035eee1b47f1299efd20af8 Mon Sep 17 00:00:00 2001 From: Doug Date: Wed, 27 Apr 2022 11:19:10 +0100 Subject: [PATCH 4/5] Tidy up AuthenticationService. --- .../Common/AuthenticationModels.swift | 4 +- .../MatrixSDK/AuthenticationService.swift | 132 +++++++++--------- .../MatrixSDK/AuthenticationState.swift | 6 +- RiotTests/AuthenticationServiceTests.swift | 12 +- 4 files changed, 78 insertions(+), 76 deletions(-) diff --git a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift index 1900f9f245..d5b769da27 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/AuthenticationModels.swift @@ -16,8 +16,8 @@ import Foundation -/// A value that dictates the authentication flow that will be used. -enum AuthenticationMode { +/// A value that represents the type of authentication flow being used. +enum AuthenticationFlow { case login case registration } diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift index 15552b3217..ba5a038151 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift @@ -50,12 +50,12 @@ class AuthenticationService: NSObject { override init() { if let homeserverURL = URL(string: RiotSettings.shared.homeserverUrlString) { // Use the same homeserver that was last used. - state = AuthenticationState(authenticationMode: .login, homeserverAddress: RiotSettings.shared.homeserverUrlString) + state = AuthenticationState(flow: .login, homeserverAddress: RiotSettings.shared.homeserverUrlString) client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) } else if let homeserverURL = URL(string: BuildSettings.serverConfigDefaultHomeserverUrlString) { // Fall back to the default homeserver if the stored one is invalid. - state = AuthenticationState(authenticationMode: .login, homeserverAddress: BuildSettings.serverConfigDefaultHomeserverUrlString) + state = AuthenticationState(flow: .login, homeserverAddress: BuildSettings.serverConfigDefaultHomeserverUrlString) client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) } else { @@ -66,32 +66,6 @@ class AuthenticationService: NSObject { super.init() } - // MARK: - Android OnboardingViewModel - - func startFlow(for homeserverAddress: String, as authenticationMode: AuthenticationMode) async throws { - reset() - - let loginFlows = try await loginFlow(for: homeserverAddress) - - // Valid Homeserver, add it to the history. - // Note: we add what the user has input, as the data can contain a different value. - RiotSettings.shared.homeserverUrlString = homeserverAddress - - state.homeserver = .init(address: loginFlows.homeserverAddress, - addressFromUser: homeserverAddress, - preferredLoginMode: loginFlows.loginMode, - loginModeSupportedTypes: loginFlows.supportedLoginTypes) - - let loginWizard = LoginWizard() - self.loginWizard = loginWizard - - if authenticationMode == .registration { - let registrationWizard = RegistrationWizard(client: client) - state.initialRegistrationFlow = try await registrationWizard.registrationFlow() - self.registrationWizard = registrationWizard - } - } - // MARK: - Public /// Whether authentication is needed by checking for any accounts. @@ -118,47 +92,30 @@ class AuthenticationService: NSObject { MXKAccountManager.shared().activeAccounts?.first?.mxSession } - /// Request the supported login flows for this homeserver. - /// This is the first method to call to be able to get a wizard to login or to create an account - /// - Parameter homeserverAddress: The homeserver string entered by the user. - private func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { - let homeserverAddress = HomeserverAddress.sanitized(homeserverAddress) - - guard var homeserverURL = URL(string: homeserverAddress) else { - throw AuthenticationError.invalidHomeserver - } - - let state = AuthenticationState(authenticationMode: .login, homeserverAddress: homeserverAddress) - - if let wellKnown = try? await wellKnown(for: homeserverURL), - let baseURL = URL(string: wellKnown.homeServer.baseUrl) { - homeserverURL = baseURL - } - - #warning("Add an unrecognized certificate handler.") - let client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + func startFlow(_ flow: AuthenticationFlow, for homeserverAddress: String) async throws { + reset() - let loginFlow = try await getLoginFlowResult(client: client) + let loginFlows = try await loginFlow(for: homeserverAddress) - self.client = client - self.state = state + // Valid Homeserver, add it to the history. + // Note: we add what the user has input, as the data can contain a different value. + RiotSettings.shared.homeserverUrlString = homeserverAddress - return loginFlow - } - - /// Request the supported login flows for the corresponding session. - /// This method is used to get the flows for a server after a soft-logout. - /// - Parameter session: The MXSession where a soft-logout has occurred. - private func loginFlow(for session: MXSession) async throws -> LoginFlowResult { - guard let client = session.matrixRestClient else { throw AuthenticationError.missingMXRestClient } - let state = AuthenticationState(authenticationMode: .login, homeserverAddress: client.homeserver) + state.homeserver = .init(address: loginFlows.homeserverAddress, + addressFromUser: homeserverAddress, + preferredLoginMode: loginFlows.loginMode, + loginModeSupportedTypes: loginFlows.supportedLoginTypes) - let loginFlow = try await getLoginFlowResult(client: session.matrixRestClient) + let loginWizard = LoginWizard() + self.loginWizard = loginWizard - self.client = client - self.state = state + if flow == .registration { + let registrationWizard = RegistrationWizard(client: client) + state.initialRegistrationFlow = try await registrationWizard.registrationFlow() + self.registrationWizard = registrationWizard + } - return loginFlow + state.flow = flow } /// Get a SSO url @@ -167,8 +124,8 @@ class AuthenticationService: NSObject { } /// Get the sign in or sign up fallback URL - func fallbackURL(for authenticationMode: AuthenticationMode) -> URL { - switch authenticationMode { + func fallbackURL(for flow: AuthenticationFlow) -> URL { + switch flow { case .login: return client.loginFallbackURL case .registration: @@ -187,7 +144,7 @@ class AuthenticationService: NSObject { registrationWizard = nil // The previously used homeserver is re-used as `startFlow` will be called again a replace it anyway. - self.state = AuthenticationState(authenticationMode: .login, homeserverAddress: state.homeserver.address) + self.state = AuthenticationState(flow: .login, homeserverAddress: state.homeserver.address) } /// Create a session after a SSO successful login @@ -218,6 +175,49 @@ class AuthenticationService: NSObject { // MARK: - Private + /// Request the supported login flows for this homeserver. + /// This is the first method to call to be able to get a wizard to login or to create an account + /// - Parameter homeserverAddress: The homeserver string entered by the user. + private func loginFlow(for homeserverAddress: String) async throws -> LoginFlowResult { + let homeserverAddress = HomeserverAddress.sanitized(homeserverAddress) + + guard var homeserverURL = URL(string: homeserverAddress) else { + throw AuthenticationError.invalidHomeserver + } + + let state = AuthenticationState(flow: .login, homeserverAddress: homeserverAddress) + + if let wellKnown = try? await wellKnown(for: homeserverURL), + let baseURL = URL(string: wellKnown.homeServer.baseUrl) { + homeserverURL = baseURL + } + + #warning("Add an unrecognized certificate handler.") + let client = MXRestClient(homeServer: homeserverURL, unrecognizedCertificateHandler: nil) + + let loginFlow = try await getLoginFlowResult(client: client) + + self.client = client + self.state = state + + return loginFlow + } + + /// Request the supported login flows for the corresponding session. + /// This method is used to get the flows for a server after a soft-logout. + /// - Parameter session: The MXSession where a soft-logout has occurred. + private func loginFlow(for session: MXSession) async throws -> LoginFlowResult { + guard let client = session.matrixRestClient else { throw AuthenticationError.missingMXRestClient } + let state = AuthenticationState(flow: .login, homeserverAddress: client.homeserver) + + let loginFlow = try await getLoginFlowResult(client: session.matrixRestClient) + + self.client = client + self.state = state + + return loginFlow + } + private func getLoginFlowResult(client: MXRestClient) async throws -> LoginFlowResult { // Get the login flow let loginFlowResponse = try await client.getLoginSession() diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift index 186e3055b3..528090fdda 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift @@ -20,7 +20,7 @@ import MatrixSDK @available(iOS 14.0, *) struct AuthenticationState { // var serverType: ServerType = .unknown - var authenticationMode: AuthenticationMode + var flow: AuthenticationFlow /// Information about the currently selected homeserver. var homeserver: Homeserver @@ -29,8 +29,8 @@ struct AuthenticationState { /// The registration flow response returned when calling `startFlow` for `.registration`. var initialRegistrationFlow: RegistrationResult? - init(authenticationMode: AuthenticationMode, homeserverAddress: String) { - self.authenticationMode = authenticationMode + init(flow: AuthenticationFlow, homeserverAddress: String) { + self.flow = flow self.homeserver = Homeserver(address: homeserverAddress) } diff --git a/RiotTests/AuthenticationServiceTests.swift b/RiotTests/AuthenticationServiceTests.swift index 5496c09b2a..68815be2cc 100644 --- a/RiotTests/AuthenticationServiceTests.swift +++ b/RiotTests/AuthenticationServiceTests.swift @@ -26,7 +26,7 @@ class AuthenticationServiceTests: XCTestCase { XCTAssertNil(service.registrationWizard, "A new service shouldn't have a registration wizard.") // When starting a new login flow. - try await service.startFlow(for: "https://matrix.org", as: .login) + try await service.startFlow(.login, for: "https://matrix.org") // Then a registration wizard shouldn't have been created. XCTAssertNil(service.registrationWizard, "The registration wizard should not exist if startFlow was called for login.") @@ -39,7 +39,7 @@ class AuthenticationServiceTests: XCTestCase { XCTAssertNil(service.state.initialRegistrationFlow, "A new service shouldn't provide an initial registration flow.") // When starting a new registration flow. - try await service.startFlow(for: "https://matrix.org", as: .registration) + try await service.startFlow(.registration, for: "https://matrix.org") // Then a registration wizard should be available for use. XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") @@ -49,12 +49,13 @@ class AuthenticationServiceTests: XCTestCase { func testReset() async throws { // Given a service that has begun registration. let service = AuthenticationService() - try await service.startFlow(for: "https://matrix.org", as: .registration) + try await service.startFlow(.registration, for: "https://matrix.org") _ = try await service.registrationWizard?.createAccount(username: UUID().uuidString, password: UUID().uuidString, initialDeviceDisplayName: "Test") XCTAssertNotNil(service.loginWizard, "The login wizard should exist after startFlow has been called.") XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") XCTAssertNotNil(service.state.initialRegistrationFlow, "An initial registration flow should exist after starting a registration flow.") XCTAssertTrue(service.isRegistrationStarted, "The service should show as having started registration.") + XCTAssertEqual(service.state.flow, .registration, "The service should show as using a registration flow.") // When resetting the service. service.reset() @@ -64,17 +65,18 @@ class AuthenticationServiceTests: XCTestCase { XCTAssertNil(service.registrationWizard, "The registration wizard should be cleared after calling reset.") XCTAssertNil(service.state.initialRegistrationFlow, "The initial registration flow should be cleared when calling reset.") XCTAssertFalse(service.isRegistrationStarted, "The service should not indicate it has started registration after calling reset.") + XCTAssertEqual(service.state.flow, .login, "The flow should have been set back to login when calling reset.") } func testHomeserverState() async throws { // Given a service that has begun login for one homeserver. let service = AuthenticationService() - try await service.startFlow(for: "https://glasgow.social", as: .login) + try await service.startFlow(.login, for: "https://glasgow.social") XCTAssertEqual(service.state.homeserver.addressFromUser, "https://glasgow.social", "The initial address entered by the user should be stored.") XCTAssertEqual(service.state.homeserver.address, "https://matrix.glasgow.social", "The initial address discovered from the well-known should be stored.") // When switching to a different homeserver - try await service.startFlow(for: "https://matrix.org", as: .login) + try await service.startFlow(.login, for: "https://matrix.org") // The the homeserver state should update to represent the new server XCTAssertEqual(service.state.homeserver.addressFromUser, "https://matrix.org", "The new address entered by the user should be stored.") From a80c33cef306fbc4c066f8cafeb7351193b9508b Mon Sep 17 00:00:00 2001 From: Doug Date: Wed, 27 Apr 2022 11:33:43 +0100 Subject: [PATCH 5/5] Move initialRegistrationFlow into AuthenticationState.Homeserver. --- .../Service/MatrixSDK/AuthenticationService.swift | 2 +- .../Common/Service/MatrixSDK/AuthenticationState.swift | 6 +++--- RiotTests/AuthenticationServiceTests.swift | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift index ba5a038151..0a09407f0b 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationService.swift @@ -111,7 +111,7 @@ class AuthenticationService: NSObject { if flow == .registration { let registrationWizard = RegistrationWizard(client: client) - state.initialRegistrationFlow = try await registrationWizard.registrationFlow() + state.homeserver.registrationFlow = try await registrationWizard.registrationFlow() self.registrationWizard = registrationWizard } diff --git a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift index 528090fdda..05f8dd7353 100644 --- a/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift +++ b/RiotSwiftUI/Modules/Authentication/Common/Service/MatrixSDK/AuthenticationState.swift @@ -26,9 +26,6 @@ struct AuthenticationState { var homeserver: Homeserver var isForceLoginFallbackEnabled = false - /// The registration flow response returned when calling `startFlow` for `.registration`. - var initialRegistrationFlow: RegistrationResult? - init(flow: AuthenticationFlow, homeserverAddress: String) { self.flow = flow self.homeserver = Homeserver(address: homeserverAddress) @@ -44,5 +41,8 @@ struct AuthenticationState { var preferredLoginMode: LoginMode = .unknown /// Supported types for the login. var loginModeSupportedTypes = [MXLoginFlow]() + + /// The response returned when querying the homeserver for registration flows. + var registrationFlow: RegistrationResult? } } diff --git a/RiotTests/AuthenticationServiceTests.swift b/RiotTests/AuthenticationServiceTests.swift index 68815be2cc..c5f520c76f 100644 --- a/RiotTests/AuthenticationServiceTests.swift +++ b/RiotTests/AuthenticationServiceTests.swift @@ -36,14 +36,14 @@ class AuthenticationServiceTests: XCTestCase { // Given a fresh service. let service = AuthenticationService() XCTAssertNil(service.registrationWizard, "A new service shouldn't provide a registration wizard.") - XCTAssertNil(service.state.initialRegistrationFlow, "A new service shouldn't provide an initial registration flow.") + XCTAssertNil(service.state.homeserver.registrationFlow, "A new service shouldn't provide a registration flow for the homeserver.") // When starting a new registration flow. try await service.startFlow(.registration, for: "https://matrix.org") // Then a registration wizard should be available for use. XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") - XCTAssertNotNil(service.state.initialRegistrationFlow, "The result from setting up a registration wizard should be available in the service.") + XCTAssertNotNil(service.state.homeserver.registrationFlow, "The supported registration flow should be stored after starting a registration flow.") } func testReset() async throws { @@ -51,9 +51,9 @@ class AuthenticationServiceTests: XCTestCase { let service = AuthenticationService() try await service.startFlow(.registration, for: "https://matrix.org") _ = try await service.registrationWizard?.createAccount(username: UUID().uuidString, password: UUID().uuidString, initialDeviceDisplayName: "Test") - XCTAssertNotNil(service.loginWizard, "The login wizard should exist after startFlow has been called.") + XCTAssertNotNil(service.loginWizard, "The login wizard should exist after starting a registration flow.") XCTAssertNotNil(service.registrationWizard, "The registration wizard should exist after starting a registration flow.") - XCTAssertNotNil(service.state.initialRegistrationFlow, "An initial registration flow should exist after starting a registration flow.") + XCTAssertNotNil(service.state.homeserver.registrationFlow, "The supported registration flow should be stored after starting a registration flow.") XCTAssertTrue(service.isRegistrationStarted, "The service should show as having started registration.") XCTAssertEqual(service.state.flow, .registration, "The service should show as using a registration flow.") @@ -63,7 +63,7 @@ class AuthenticationServiceTests: XCTestCase { // Then the wizards should no longer exist. XCTAssertNil(service.loginWizard, "The login wizard should be cleared after calling reset.") XCTAssertNil(service.registrationWizard, "The registration wizard should be cleared after calling reset.") - XCTAssertNil(service.state.initialRegistrationFlow, "The initial registration flow should be cleared when calling reset.") + XCTAssertNil(service.state.homeserver.registrationFlow, "The supported registration flow should be cleared when calling reset.") XCTAssertFalse(service.isRegistrationStarted, "The service should not indicate it has started registration after calling reset.") XCTAssertEqual(service.state.flow, .login, "The flow should have been set back to login when calling reset.") }