Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor channel connectivity to avoid multiple spin loops #380

Merged
merged 9 commits into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 28 additions & 90 deletions Sources/SwiftGRPC/Core/Channel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,26 @@
* limitations under the License.
*/
#if SWIFT_PACKAGE
import CgRPC
import Dispatch
import CgRPC
#endif
import Foundation

/// A gRPC Channel
public class Channel {
private let mutex = Mutex()
/// Pointer to underlying C representation
private let underlyingChannel: UnsafeMutableRawPointer

/// Completion queue for channel call operations
private let completionQueue: CompletionQueue
/// Observer for connectivity state changes. Created lazily if needed
private var connectivityObserver: ConnectivityObserver?

/// Timeout for new calls
public var timeout: TimeInterval = 600.0

/// Default host to use for new calls
public var host: String

/// Connectivity state observers
private var connectivityObservers: [ConnectivityObserver] = []

/// Initializes a gRPC channel
///
/// - Parameter address: the address of the server to be called
Expand All @@ -47,12 +45,12 @@ public class Channel {
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
if secure {
return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
} else {
return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
}
var argumentValues = argumentWrappers.map { $0.wrapped }
rebello95 marked this conversation as resolved.
Show resolved Hide resolved
if secure {
return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
} else {
return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
}
}
completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
completionQueue.run() // start a loop that watches the channel's completion queue
Expand All @@ -66,10 +64,10 @@ public class Channel {
gRPC.initialize()
host = googleAddress
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
}

completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
Expand All @@ -89,17 +87,17 @@ public class Channel {
let argumentWrappers = arguments.map { $0.toCArg() }

underlyingChannel = withExtendedLifetime(argumentWrappers) {
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
var argumentValues = argumentWrappers.map { $0.wrapped }
return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
}
completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
completionQueue.run() // start a loop that watches the channel's completion queue
}

deinit {
connectivityObservers.forEach { $0.shutdown() }
cgrpc_channel_destroy(underlyingChannel)
completionQueue.shutdown()
self.connectivityObserver?.shutdown()
rebello95 marked this conversation as resolved.
Show resolved Hide resolved
cgrpc_channel_destroy(self.underlyingChannel)
self.completionQueue.shutdown()
}

/// Constructs a Call object to make a gRPC API call
Expand All @@ -109,7 +107,7 @@ public class Channel {
/// - Parameter timeout: a timeout value in seconds
/// - Returns: a Call object that can be used to perform the request
public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call {
let host = (host == "") ? self.host : host
let host = host.isEmpty ? self.host : host
let timeout = timeout ?? self.timeout
let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)!
return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue)
Expand All @@ -126,77 +124,17 @@ public class Channel {
/// Subscribe to connectivity state changes
///
/// - Parameter callback: block executed every time a new connectivity state is detected
public func subscribe(callback: @escaping (ConnectivityState) -> Void) {
connectivityObservers.append(ConnectivityObserver(underlyingChannel: underlyingChannel, currentState: connectivityState(), callback: callback))
}
}

private extension Channel {
final class ConnectivityObserver {
private let completionQueue: CompletionQueue
private let underlyingChannel: UnsafeMutableRawPointer
private let underlyingCompletionQueue: UnsafeMutableRawPointer
private let callback: (ConnectivityState) -> Void
private var lastState: ConnectivityState
private var hasBeenShutdown = false
private let stateMutex: Mutex = Mutex()

init(underlyingChannel: UnsafeMutableRawPointer, currentState: ConnectivityState, callback: @escaping (ConnectivityState) -> ()) {
self.underlyingChannel = underlyingChannel
self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next()
self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, name: "Connectivity State")
self.callback = callback
self.lastState = currentState
run()
}

deinit {
shutdown()
}

private func run() {
let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread")

spinloopThreadQueue.async {
while true {
guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
return
}

guard let underlyingState = self.lastState.underlyingState else { return }

let deadline: TimeInterval = 0.2
cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, underlyingState, deadline, nil)
let event = self.completionQueue.wait(timeout: deadline)

guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
return
}

switch event.type {
case .complete:
let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))

if newState != self.lastState {
self.callback(newState)
}
self.lastState = newState

case .queueShutdown:
return

default:
continue
}
}
public func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
self.mutex.synchronize {
let observer: ConnectivityObserver
if let existingObserver = self.connectivityObserver {
observer = existingObserver
} else {
observer = ConnectivityObserver(underlyingChannel: self.underlyingChannel)
self.connectivityObserver = observer
}
}

func shutdown() {
stateMutex.synchronize {
hasBeenShutdown = true
}
completionQueue.shutdown()
observer.addConnectivityObserver(callback: callback)
}
}
}
97 changes: 97 additions & 0 deletions Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2016, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if SWIFT_PACKAGE
import CgRPC
import Dispatch
#endif
import Foundation

extension Channel {
/// Provides an interface for observing the connectivity of a given channel.
final class ConnectivityObserver {
private let mutex = Mutex()
private let completionQueue: CompletionQueue
private let underlyingChannel: UnsafeMutableRawPointer
private let underlyingCompletionQueue: UnsafeMutableRawPointer
private var callbacks = [(ConnectivityState) -> Void]()
private var hasBeenShutdown = false

init(underlyingChannel: UnsafeMutableRawPointer) {
self.underlyingChannel = underlyingChannel
self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next()
self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue,
name: "Connectivity State")
self.run()
}

deinit {
self.shutdown()
}

func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
self.mutex.synchronize {
self.callbacks.append(callback)
}
}

func shutdown() {
self.mutex.synchronize {
guard !self.hasBeenShutdown else { return }

self.hasBeenShutdown = true
self.completionQueue.shutdown()
}
}

// MARK: - Private

private func run() {
let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread")
var lastState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
spinloopThreadQueue.async {
while (self.mutex.synchronize { !self.hasBeenShutdown }) {
guard let underlyingState = lastState.underlyingState else { return }

let deadline: TimeInterval = 0.2
cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue,
underlyingState, deadline, nil)

let event = self.completionQueue.wait(timeout: deadline)
guard (self.mutex.synchronize { !self.hasBeenShutdown }) else {
return
}

switch event.type {
case .complete:
let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
guard newState != lastState else { continue }

lastState = newState
self.mutex.synchronize {
self.callbacks.forEach { callback in callback(newState) }
rebello95 marked this conversation as resolved.
Show resolved Hide resolved
}

case .queueShutdown:
return

default:
continue
}
}
}
}
}
}
24 changes: 21 additions & 3 deletions Tests/SwiftGRPCTests/ChannelConnectivityTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ final class ChannelConnectivityTests: BasicEchoTestCase {

static var allTests: [(String, (ChannelConnectivityTests) -> () throws -> Void)] {
return [
("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash)
("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash),
("testMultipleConnectivityObserversAreCalled", testMultipleConnectivityObserversAreCalled),
]
}
}
Expand All @@ -30,12 +31,12 @@ extension ChannelConnectivityTests {
func testDanglingConnectivityObserversDontCrash() {
let completionHandlerExpectation = expectation(description: "completion handler called")

client?.channel.subscribe { connectivityState in
client.channel.addConnectivityObserver { connectivityState in
print("ConnectivityState: \(connectivityState)")
}

let request = Echo_EchoRequest(text: "foo bar baz foo bar baz")
_ = try! client!.expand(request) { callResult in
_ = try! client.expand(request) { callResult in
print("callResult.statusCode: \(callResult.statusCode)")
completionHandlerExpectation.fulfill()
}
Expand All @@ -46,4 +47,21 @@ extension ChannelConnectivityTests {

waitForExpectations(timeout: 0.5)
}

func testMultipleConnectivityObserversAreCalled() {
let completionHandlerExpectation = expectation(description: "completion handler called")
var firstObserverCalled = false
var secondObserverCalled = false

client.channel.addConnectivityObserver { _ in firstObserverCalled = true }
rebello95 marked this conversation as resolved.
Show resolved Hide resolved
client.channel.addConnectivityObserver { _ in secondObserverCalled = true }

_ = try! client.expand(Echo_EchoRequest(text: "foo bar baz foo bar baz")) { _ in
completionHandlerExpectation.fulfill()
}

waitForExpectations(timeout: 0.5)
XCTAssertTrue(firstObserverCalled)
XCTAssertTrue(secondObserverCalled)
}
}