Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pull channel creation out of ConnectionManager #1158

Merged
merged 1 commit into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 76 additions & 92 deletions Sources/GRPC/ConnectionManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import NIO
import NIOConcurrencyHelpers
import NIOHTTP2

internal class ConnectionManager {
internal final class ConnectionManager {
internal enum Reconnect {
case none
case after(TimeInterval)
Expand Down Expand Up @@ -203,10 +203,25 @@ internal class ConnectionManager {
}
}

/// The `EventLoop` that the managed connection will run on.
internal let eventLoop: EventLoop

/// A connectivity state monitor.
internal let monitor: ConnectivityStateMonitor

/// An `EventLoopFuture<Channel>` provider.
private let channelProvider: ConnectionManagerChannelProvider

/// The behavior for starting a call, i.e. how patient is the caller when asking for a
/// multiplexer.
private let callStartBehavior: CallStartBehavior.Behavior

/// The configuration to use when backing off between connection attempts, if reconnection
/// attempts should be made at all.
private let connectionBackoff: ConnectionBackoff?

/// A logger.
internal var logger: Logger
private let configuration: ClientConnection.Configuration

private let connectionID: String
private var channelNumber: UInt64
Expand All @@ -233,11 +248,12 @@ internal class ConnectionManager {
logger[metadataKey: MetadataKey.connectionID] = "\(self.connectionIDAndNumber)"
}

// Only used for testing.
private var channelProvider: (() -> EventLoopFuture<Channel>)?

internal convenience init(configuration: ClientConnection.Configuration, logger: Logger) {
self.init(configuration: configuration, logger: logger, channelProvider: nil)
self.init(
configuration: configuration,
channelProvider: ClientConnection.ChannelProvider(configuration: configuration),
logger: logger
)
}

/// Create a `ConnectionManager` for testing: uses the given `channelProvider` to create channels.
Expand All @@ -246,34 +262,66 @@ internal class ConnectionManager {
logger: Logger,
channelProvider: @escaping () -> EventLoopFuture<Channel>
) -> ConnectionManager {
struct Wrapper: ConnectionManagerChannelProvider {
var callback: () -> EventLoopFuture<Channel>
func makeChannel(
managedBy connectionManager: ConnectionManager,
onEventLoop eventLoop: EventLoop,
connectTimeout: TimeAmount?,
logger: Logger
) -> EventLoopFuture<Channel> {
return self.callback().hop(to: eventLoop)
}
}

return ConnectionManager(
configuration: configuration,
logger: logger,
channelProvider: channelProvider
channelProvider: Wrapper(callback: channelProvider),
logger: logger
)
}

private init(
private convenience init(
configuration: ClientConnection.Configuration,
logger: Logger,
channelProvider: (() -> EventLoopFuture<Channel>)?
channelProvider: ConnectionManagerChannelProvider,
logger: Logger
) {
self.init(
eventLoop: configuration.eventLoopGroup.next(),
channelProvider: channelProvider,
callStartBehavior: configuration.callStartBehavior.wrapped,
connectionBackoff: configuration.connectionBackoff,
connectivityStateDelegate: configuration.connectivityStateDelegate,
connectivityStateDelegateQueue: configuration.connectivityStateDelegateQueue,
logger: logger
)
}

private init(
eventLoop: EventLoop,
channelProvider: ConnectionManagerChannelProvider,
callStartBehavior: CallStartBehavior.Behavior,
connectionBackoff: ConnectionBackoff?,
connectivityStateDelegate: ConnectivityStateDelegate?,
connectivityStateDelegateQueue: DispatchQueue?,
logger: Logger
) {
// Setup the logger.
var logger = logger
let connectionID = UUID().uuidString
let channelNumber: UInt64 = 0
logger[metadataKey: MetadataKey.connectionID] = "\(connectionID)/\(channelNumber)"

let eventLoop = configuration.eventLoopGroup.next()
self.eventLoop = eventLoop
self.state = .idle
self.monitor = ConnectivityStateMonitor(
delegate: configuration.connectivityStateDelegate,
queue: configuration.connectivityStateDelegateQueue
)
self.configuration = configuration

self.channelProvider = channelProvider
self.callStartBehavior = callStartBehavior
self.connectionBackoff = connectionBackoff
self.monitor = ConnectivityStateMonitor(
delegate: connectivityStateDelegate,
queue: connectivityStateDelegateQueue
)

self.connectionID = connectionID
self.channelNumber = channelNumber
Expand All @@ -285,7 +333,7 @@ internal class ConnectionManager {
/// one chance to connect - if not reconnections are managed here.
internal func getHTTP2Multiplexer() -> EventLoopFuture<HTTP2StreamMultiplexer> {
func getHTTP2Multiplexer0() -> EventLoopFuture<HTTP2StreamMultiplexer> {
switch self.configuration.callStartBehavior.wrapped {
switch self.callStartBehavior {
case .waitsForConnectivity:
return self.getHTTP2MultiplexerPatient()
case .fastFailure:
Expand Down Expand Up @@ -564,7 +612,7 @@ internal class ConnectionManager {
// the channel?
case let .ready(ready):
// No, no backoff is configured.
if self.configuration.connectionBackoff == nil {
if self.connectionBackoff == nil {
self.logger.debug("shutting down connection, no reconnect configured/remaining")
self.state = .shutdown(
ShutdownState(
Expand All @@ -581,7 +629,7 @@ internal class ConnectionManager {
self.startConnecting()
}
self.logger.debug("scheduling connection attempt", metadata: ["delay": "0"])
let backoffIterator = self.configuration.connectionBackoff?.makeIterator()
let backoffIterator = self.connectionBackoff?.makeIterator()
self.state = .transientFailure(TransientFailureState(
from: ready,
scheduled: scheduled,
Expand Down Expand Up @@ -747,7 +795,7 @@ extension ConnectionManager {
private func startConnecting() {
switch self.state {
case .idle:
let iterator = self.configuration.connectionBackoff?.makeIterator()
let iterator = self.connectionBackoff?.makeIterator()
self.startConnecting(
backoffIterator: iterator,
muxPromise: self.eventLoop.makePromise()
Expand Down Expand Up @@ -788,12 +836,17 @@ extension ConnectionManager {
self.eventLoop.assertInEventLoop()

let candidate: EventLoopFuture<Channel> = self.eventLoop.flatSubmit {
let channel = self.makeChannel(
connectTimeout: timeoutAndBackoff?.timeout
let channel: EventLoopFuture<Channel> = self.channelProvider.makeChannel(
managedBy: self,
onEventLoop: self.eventLoop,
connectTimeout: timeoutAndBackoff.map { .seconds(timeInterval: $0.timeout) },
logger: self.logger
)

channel.whenFailure { error in
self.connectionFailed(withError: error)
}

return channel
}

Expand All @@ -820,72 +873,3 @@ extension ConnectionManager {
preconditionFailure("Invalid state \(self.state) for \(function)", file: file, line: line)
}
}

extension ConnectionManager {
private func makeBootstrap(
connectTimeout: TimeInterval?
) -> ClientBootstrapProtocol {
let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in
if let hostnameOverride = tls.hostnameOverride {
return hostnameOverride
} else {
return configuration.target.host
}
}.flatMap { hostname in
if hostname.isIPAddress {
return nil
} else {
return hostname
}
}

let bootstrap = PlatformSupport.makeClientBootstrap(group: self.eventLoop, logger: self.logger)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
.channelInitializer { channel in
let initialized = channel.configureGRPCClient(
httpTargetWindowSize: self.configuration.httpTargetWindowSize,
tlsConfiguration: self.configuration.tls?.configuration,
tlsServerHostname: serverHostname,
connectionManager: self,
connectionKeepalive: self.configuration.connectionKeepalive,
connectionIdleTimeout: self.configuration.connectionIdleTimeout,
errorDelegate: self.configuration.errorDelegate,
requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround(
group: self.eventLoop,
hasTLS: self.configuration.tls != nil
),
logger: self.logger,
customVerificationCallback: self.configuration.tls?.customVerificationCallback
)

// Run the debug initializer, if there is one.
if let debugInitializer = self.configuration.debugChannelInitializer {
return initialized.flatMap {
debugInitializer(channel)
}
} else {
return initialized
}
}

if let connectTimeout = connectTimeout {
return bootstrap.connectTimeout(.seconds(timeInterval: connectTimeout))
} else {
return bootstrap
}
}

private func makeChannel(
connectTimeout: TimeInterval?
) -> EventLoopFuture<Channel> {
if let provider = self.channelProvider {
return provider()
} else {
let bootstrap = self.makeBootstrap(
connectTimeout: connectTimeout
)
return bootstrap.connect(to: self.configuration.target)
}
}
}
102 changes: 102 additions & 0 deletions Sources/GRPC/ConnectionManagerChannelProvider.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import Logging
import NIO

internal protocol ConnectionManagerChannelProvider {
/// Make an `EventLoopFuture<Channel>`.
///
/// - Parameters:
/// - connectionManager: The `ConnectionManager` requesting the `Channel`.
/// - eventLoop: The `EventLoop` to use for the`Channel`.
/// - connectTimeout: Optional connection timeout when starting the connection.
/// - logger: A logger.
func makeChannel(
managedBy connectionManager: ConnectionManager,
onEventLoop eventLoop: EventLoop,
connectTimeout: TimeAmount?,
logger: Logger
) -> EventLoopFuture<Channel>
}

extension ClientConnection {
internal struct ChannelProvider {
private var configuration: Configuration

internal init(configuration: Configuration) {
self.configuration = configuration
}
}
}

extension ClientConnection.ChannelProvider: ConnectionManagerChannelProvider {
internal func makeChannel(
managedBy connectionManager: ConnectionManager,
onEventLoop eventLoop: EventLoop,
connectTimeout: TimeAmount?,
logger: Logger
) -> EventLoopFuture<Channel> {
let serverHostname: String? = self.configuration.tls.flatMap { tls -> String? in
if let hostnameOverride = tls.hostnameOverride {
return hostnameOverride
} else {
return self.configuration.target.host
}
}.flatMap { hostname in
if hostname.isIPAddress {
return nil
} else {
return hostname
}
}

let bootstrap = PlatformSupport.makeClientBootstrap(group: eventLoop, logger: logger)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
.channelInitializer { channel in
let initialized = channel.configureGRPCClient(
httpTargetWindowSize: self.configuration.httpTargetWindowSize,
tlsConfiguration: self.configuration.tls?.configuration,
tlsServerHostname: serverHostname,
connectionManager: connectionManager,
connectionKeepalive: self.configuration.connectionKeepalive,
connectionIdleTimeout: self.configuration.connectionIdleTimeout,
errorDelegate: self.configuration.errorDelegate,
requiresZeroLengthWriteWorkaround: PlatformSupport.requiresZeroLengthWriteWorkaround(
group: eventLoop,
hasTLS: self.configuration.tls != nil
),
logger: logger,
customVerificationCallback: self.configuration.tls?.customVerificationCallback
)

// Run the debug initializer, if there is one.
if let debugInitializer = self.configuration.debugChannelInitializer {
return initialized.flatMap {
debugInitializer(channel)
}
} else {
return initialized
}
}

if let connectTimeout = connectTimeout {
_ = bootstrap.connectTimeout(connectTimeout)
}

return bootstrap.connect(to: self.configuration.target)
}
}