diff --git a/Sources/GRPC/ConnectionManager.swift b/Sources/GRPC/ConnectionManager.swift index 6da8655d1..6d1b6ffe9 100644 --- a/Sources/GRPC/ConnectionManager.swift +++ b/Sources/GRPC/ConnectionManager.swift @@ -19,7 +19,7 @@ import NIO import NIOConcurrencyHelpers import NIOHTTP2 -internal class ConnectionManager { +internal final class ConnectionManager { internal enum Reconnect { case none case after(TimeInterval) @@ -203,10 +203,25 @@ internal class ConnectionManager { } } + /// The `EventLoop` that the managed connection will run on. internal let eventLoop: EventLoop + + /// A connectivity state monitor. internal let monitor: ConnectivityStateMonitor + + /// An `EventLoopFuture` provider. + private let channelProvider: ConnectionManagerChannelProvider + + /// The behavior for starting a call, i.e. how patient is the caller when asking for a + /// multiplexer. + private let callStartBehavior: CallStartBehavior.Behavior + + /// The configuration to use when backing off between connection attempts, if reconnection + /// attempts should be made at all. + private let connectionBackoff: ConnectionBackoff? + + /// A logger. internal var logger: Logger - private let configuration: ClientConnection.Configuration private let connectionID: String private var channelNumber: UInt64 @@ -233,11 +248,12 @@ internal class ConnectionManager { logger[metadataKey: MetadataKey.connectionID] = "\(self.connectionIDAndNumber)" } - // Only used for testing. - private var channelProvider: (() -> EventLoopFuture)? - internal convenience init(configuration: ClientConnection.Configuration, logger: Logger) { - self.init(configuration: configuration, logger: logger, channelProvider: nil) + self.init( + configuration: configuration, + channelProvider: ClientConnection.ChannelProvider(configuration: configuration), + logger: logger + ) } /// Create a `ConnectionManager` for testing: uses the given `channelProvider` to create channels. @@ -246,17 +262,49 @@ internal class ConnectionManager { logger: Logger, channelProvider: @escaping () -> EventLoopFuture ) -> ConnectionManager { + struct Wrapper: ConnectionManagerChannelProvider { + var callback: () -> EventLoopFuture + func makeChannel( + managedBy connectionManager: ConnectionManager, + onEventLoop eventLoop: EventLoop, + connectTimeout: TimeAmount?, + logger: Logger + ) -> EventLoopFuture { + return self.callback().hop(to: eventLoop) + } + } + return ConnectionManager( configuration: configuration, - logger: logger, - channelProvider: channelProvider + channelProvider: Wrapper(callback: channelProvider), + logger: logger ) } - private init( + private convenience init( configuration: ClientConnection.Configuration, - logger: Logger, - channelProvider: (() -> EventLoopFuture)? + channelProvider: ConnectionManagerChannelProvider, + logger: Logger + ) { + self.init( + eventLoop: configuration.eventLoopGroup.next(), + channelProvider: channelProvider, + callStartBehavior: configuration.callStartBehavior.wrapped, + connectionBackoff: configuration.connectionBackoff, + connectivityStateDelegate: configuration.connectivityStateDelegate, + connectivityStateDelegateQueue: configuration.connectivityStateDelegateQueue, + logger: logger + ) + } + + private init( + eventLoop: EventLoop, + channelProvider: ConnectionManagerChannelProvider, + callStartBehavior: CallStartBehavior.Behavior, + connectionBackoff: ConnectionBackoff?, + connectivityStateDelegate: ConnectivityStateDelegate?, + connectivityStateDelegateQueue: DispatchQueue?, + logger: Logger ) { // Setup the logger. var logger = logger @@ -264,16 +312,16 @@ internal class ConnectionManager { let channelNumber: UInt64 = 0 logger[metadataKey: MetadataKey.connectionID] = "\(connectionID)/\(channelNumber)" - let eventLoop = configuration.eventLoopGroup.next() self.eventLoop = eventLoop self.state = .idle - self.monitor = ConnectivityStateMonitor( - delegate: configuration.connectivityStateDelegate, - queue: configuration.connectivityStateDelegateQueue - ) - self.configuration = configuration self.channelProvider = channelProvider + self.callStartBehavior = callStartBehavior + self.connectionBackoff = connectionBackoff + self.monitor = ConnectivityStateMonitor( + delegate: connectivityStateDelegate, + queue: connectivityStateDelegateQueue + ) self.connectionID = connectionID self.channelNumber = channelNumber @@ -285,7 +333,7 @@ internal class ConnectionManager { /// one chance to connect - if not reconnections are managed here. internal func getHTTP2Multiplexer() -> EventLoopFuture { func getHTTP2Multiplexer0() -> EventLoopFuture { - switch self.configuration.callStartBehavior.wrapped { + switch self.callStartBehavior { case .waitsForConnectivity: return self.getHTTP2MultiplexerPatient() case .fastFailure: @@ -564,7 +612,7 @@ internal class ConnectionManager { // the channel? case let .ready(ready): // No, no backoff is configured. - if self.configuration.connectionBackoff == nil { + if self.connectionBackoff == nil { self.logger.debug("shutting down connection, no reconnect configured/remaining") self.state = .shutdown( ShutdownState( @@ -581,7 +629,7 @@ internal class ConnectionManager { self.startConnecting() } self.logger.debug("scheduling connection attempt", metadata: ["delay": "0"]) - let backoffIterator = self.configuration.connectionBackoff?.makeIterator() + let backoffIterator = self.connectionBackoff?.makeIterator() self.state = .transientFailure(TransientFailureState( from: ready, scheduled: scheduled, @@ -747,7 +795,7 @@ extension ConnectionManager { private func startConnecting() { switch self.state { case .idle: - let iterator = self.configuration.connectionBackoff?.makeIterator() + let iterator = self.connectionBackoff?.makeIterator() self.startConnecting( backoffIterator: iterator, muxPromise: self.eventLoop.makePromise() @@ -788,12 +836,17 @@ extension ConnectionManager { self.eventLoop.assertInEventLoop() let candidate: EventLoopFuture = self.eventLoop.flatSubmit { - let channel = self.makeChannel( - connectTimeout: timeoutAndBackoff?.timeout + let channel: EventLoopFuture = self.channelProvider.makeChannel( + managedBy: self, + onEventLoop: self.eventLoop, + connectTimeout: timeoutAndBackoff.map { .seconds(timeInterval: $0.timeout) }, + logger: self.logger ) + channel.whenFailure { error in self.connectionFailed(withError: error) } + return channel } @@ -820,72 +873,3 @@ extension ConnectionManager { preconditionFailure("Invalid state \(self.state) for \(function)", file: file, line: line) } } - -extension ConnectionManager { - private func makeBootstrap( - connectTimeout: TimeInterval? - ) -> ClientBootstrapProtocol { - let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in - if let hostnameOverride = tls.hostnameOverride { - return hostnameOverride - } else { - return configuration.target.host - } - }.flatMap { hostname in - if hostname.isIPAddress { - return nil - } else { - return hostname - } - } - - let bootstrap = PlatformSupport.makeClientBootstrap(group: self.eventLoop, logger: self.logger) - .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .channelInitializer { channel in - let initialized = channel.configureGRPCClient( - httpTargetWindowSize: self.configuration.httpTargetWindowSize, - tlsConfiguration: self.configuration.tls?.configuration, - tlsServerHostname: serverHostname, - connectionManager: self, - connectionKeepalive: self.configuration.connectionKeepalive, - connectionIdleTimeout: self.configuration.connectionIdleTimeout, - errorDelegate: self.configuration.errorDelegate, - requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround( - group: self.eventLoop, - hasTLS: self.configuration.tls != nil - ), - logger: self.logger, - customVerificationCallback: self.configuration.tls?.customVerificationCallback - ) - - // Run the debug initializer, if there is one. - if let debugInitializer = self.configuration.debugChannelInitializer { - return initialized.flatMap { - debugInitializer(channel) - } - } else { - return initialized - } - } - - if let connectTimeout = connectTimeout { - return bootstrap.connectTimeout(.seconds(timeInterval: connectTimeout)) - } else { - return bootstrap - } - } - - private func makeChannel( - connectTimeout: TimeInterval? - ) -> EventLoopFuture { - if let provider = self.channelProvider { - return provider() - } else { - let bootstrap = self.makeBootstrap( - connectTimeout: connectTimeout - ) - return bootstrap.connect(to: self.configuration.target) - } - } -} diff --git a/Sources/GRPC/ConnectionManagerChannelProvider.swift b/Sources/GRPC/ConnectionManagerChannelProvider.swift new file mode 100644 index 000000000..9c568f017 --- /dev/null +++ b/Sources/GRPC/ConnectionManagerChannelProvider.swift @@ -0,0 +1,102 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Logging +import NIO + +internal protocol ConnectionManagerChannelProvider { + /// Make an `EventLoopFuture`. + /// + /// - Parameters: + /// - connectionManager: The `ConnectionManager` requesting the `Channel`. + /// - eventLoop: The `EventLoop` to use for the`Channel`. + /// - connectTimeout: Optional connection timeout when starting the connection. + /// - logger: A logger. + func makeChannel( + managedBy connectionManager: ConnectionManager, + onEventLoop eventLoop: EventLoop, + connectTimeout: TimeAmount?, + logger: Logger + ) -> EventLoopFuture +} + +extension ClientConnection { + internal struct ChannelProvider { + private var configuration: Configuration + + internal init(configuration: Configuration) { + self.configuration = configuration + } + } +} + +extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider { + internal func makeChannel( + managedBy connectionManager: ConnectionManager, + onEventLoop eventLoop: EventLoop, + connectTimeout: TimeAmount?, + logger: Logger + ) -> EventLoopFuture { + let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in + if let hostnameOverride = tls.hostnameOverride { + return hostnameOverride + } else { + return self.configuration.target.host + } + }.flatMap { hostname in + if hostname.isIPAddress { + return nil + } else { + return hostname + } + } + + let bootstrap = PlatformSupport.makeClientBootstrap(group: eventLoop, logger: logger) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) + .channelInitializer { channel in + let initialized = channel.configureGRPCClient( + httpTargetWindowSize: self.configuration.httpTargetWindowSize, + tlsConfiguration: self.configuration.tls?.configuration, + tlsServerHostname: serverHostname, + connectionManager: connectionManager, + connectionKeepalive: self.configuration.connectionKeepalive, + connectionIdleTimeout: self.configuration.connectionIdleTimeout, + errorDelegate: self.configuration.errorDelegate, + requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround( + group: eventLoop, + hasTLS: self.configuration.tls != nil + ), + logger: logger, + customVerificationCallback: self.configuration.tls?.customVerificationCallback + ) + + // Run the debug initializer, if there is one. + if let debugInitializer = self.configuration.debugChannelInitializer { + return initialized.flatMap { + debugInitializer(channel) + } + } else { + return initialized + } + } + + if let connectTimeout = connectTimeout { + _ = bootstrap.connectTimeout(connectTimeout) + } + + return bootstrap.connect(to: self.configuration.target) + } +}