diff --git a/Examples/SimpleXcode/Simple/Document.swift b/Examples/SimpleXcode/Simple/Document.swift index 7bf8de8bf..aeded6892 100644 --- a/Examples/SimpleXcode/Simple/Document.swift +++ b/Examples/SimpleXcode/Simple/Document.swift @@ -144,16 +144,17 @@ class Document: NSDocument { if !self.isRunning() { break } - let method = (i < steps) ? "/hello" : "/quit" - let call = self.channel.makeCall(method) - - let metadata = try! Metadata([ - "x": "xylophone", - "y": "yu", - "z": "zither" - ]) do { + let method = (i < steps) ? "/hello" : "/quit" + let call = try self.channel.makeCall(method) + + let metadata = try Metadata([ + "x": "xylophone", + "y": "yu", + "z": "zither" + ]) + try call.start(.unary, metadata: metadata, message: messageData) { callResult in diff --git a/Sources/Examples/Simple/main.swift b/Sources/Examples/Simple/main.swift index 2840f1729..e935e0b6b 100644 --- a/Sources/Examples/Simple/main.swift +++ b/Sources/Examples/Simple/main.swift @@ -30,7 +30,7 @@ func client() throws { let method = (i < steps - 1) ? "/hello" : "/quit" print("calling " + method) - let call = c.makeCall(method) + let call = try c.makeCall(method) let metadata = try Metadata([ "x": "xylophone", @@ -38,7 +38,7 @@ func client() throws { "z": "zither" ]) - try! call.start(.unary, metadata: metadata, message: message) { + try call.start(.unary, metadata: metadata, message: message) { response in print("status:", response.statusCode) print("statusMessage:", response.statusMessage!) diff --git a/Sources/SwiftGRPC/Core/Channel.swift b/Sources/SwiftGRPC/Core/Channel.swift index 4fffdf704..baa3417ff 100644 --- a/Sources/SwiftGRPC/Core/Channel.swift +++ b/Sources/SwiftGRPC/Core/Channel.swift @@ -18,6 +18,16 @@ import CgRPC #endif import Foundation +/// Used to hold weak references to objects since `NSHashTable.weakObjects()` isn't available on Linux. +/// If/when this type becomes available on Linux, this should be replaced. +private final class WeakReference { + private(set) weak var value: T? + + init(value: T) { + self.value = value + } +} + /// A gRPC Channel public class Channel { private let mutex = Mutex() @@ -25,8 +35,12 @@ public class Channel { private let underlyingChannel: UnsafeMutableRawPointer /// Completion queue for channel call operations private let completionQueue: CompletionQueue + /// Weak references to API calls using this channel that are in-flight + private var activeCalls = [WeakReference]() /// Observer for connectivity state changes. Created lazily if needed private var connectivityObserver: ConnectivityObserver? + /// Whether the gRPC channel has been shut down + private var hasBeenShutdown = false /// Timeout for new calls public var timeout: TimeInterval = 600.0 @@ -34,44 +48,45 @@ public class Channel { /// Default host to use for new calls public var host: String + /// Errors that may be thrown by the channel + enum Error: Swift.Error { + /// Action cannot be performed because the channel has already been shut down + case alreadyShutdown + /// Failed to create a new call within the gRPC stack + case callCreationFailed + } + /// Initializes a gRPC channel /// /// - Parameter address: the address of the server to be called /// - Parameter secure: if true, use TLS /// - Parameter arguments: list of channel configuration options - public init(address: String, secure: Bool = true, arguments: [Argument] = []) { + public convenience init(address: String, secure: Bool = true, arguments: [Argument] = []) { gRPC.initialize() - host = address - let argumentWrappers = arguments.map { $0.toCArg() } - underlyingChannel = withExtendedLifetime(argumentWrappers) { + let argumentWrappers = arguments.map { $0.toCArg() } + self.init(host: address, underlyingChannel: withExtendedLifetime(argumentWrappers) { var argumentValues = argumentWrappers.map { $0.wrapped } if secure { return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count)) } else { return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count)) } - } - completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") - completionQueue.run() // start a loop that watches the channel's completion queue + }) } /// Initializes a gRPC channel /// /// - Parameter address: the address of the server to be called /// - Parameter arguments: list of channel configuration options - public init(googleAddress: String, arguments: [Argument] = []) { + public convenience init(googleAddress: String, arguments: [Argument] = []) { gRPC.initialize() - host = googleAddress - let argumentWrappers = arguments.map { $0.toCArg() } - underlyingChannel = withExtendedLifetime(argumentWrappers) { + let argumentWrappers = arguments.map { $0.toCArg() } + self.init(host: googleAddress, underlyingChannel: withExtendedLifetime(argumentWrappers) { var argumentValues = argumentWrappers.map { $0.wrapped } return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count)) - } - - completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") - completionQueue.run() // start a loop that watches the channel's completion queue + }) } /// Initializes a gRPC channel @@ -81,25 +96,31 @@ public class Channel { /// - Parameter clientCertificates: a PEM representation of the client certificates to use /// - Parameter clientKey: a PEM representation of the client key to use /// - Parameter arguments: list of channel configuration options - public init(address: String, certificates: String = kRootCertificates, clientCertificates: String? = nil, clientKey: String? = nil, arguments: [Argument] = []) { + public convenience init(address: String, certificates: String = kRootCertificates, clientCertificates: String? = nil, clientKey: String? = nil, arguments: [Argument] = []) { gRPC.initialize() - host = address - let argumentWrappers = arguments.map { $0.toCArg() } - underlyingChannel = withExtendedLifetime(argumentWrappers) { + let argumentWrappers = arguments.map { $0.toCArg() } + self.init(host: address, underlyingChannel: withExtendedLifetime(argumentWrappers) { var argumentValues = argumentWrappers.map { $0.wrapped } return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count)) - } - completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") - completionQueue.run() // start a loop that watches the channel's completion queue + }) } - deinit { + /// Shut down the channel. No new calls may be made using this channel after it is shut down. Any in-flight calls using this channel will be canceled + public func shutdown() { self.mutex.synchronize { + guard !self.hasBeenShutdown else { return } + + self.hasBeenShutdown = true self.connectivityObserver?.shutdown() + cgrpc_channel_destroy(self.underlyingChannel) + self.completionQueue.shutdown() + self.activeCalls.forEach { $0.value?.cancel() } } - cgrpc_channel_destroy(self.underlyingChannel) - self.completionQueue.shutdown() + } + + deinit { + self.shutdown() } /// Constructs a Call object to make a gRPC API call @@ -108,11 +129,21 @@ public class Channel { /// - Parameter host: the gRPC host name for the call. If unspecified, defaults to the Client host /// - Parameter timeout: a timeout value in seconds /// - Returns: a Call object that can be used to perform the request - public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call { - let host = host.isEmpty ? self.host : host - let timeout = timeout ?? self.timeout - let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)! - return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue) + public func makeCall(_ method: String, host: String? = nil, timeout: TimeInterval? = nil) throws -> Call { + self.mutex.lock() + defer { self.mutex.unlock() } + + guard !self.hasBeenShutdown else { + throw Error.alreadyShutdown + } + + guard let underlyingCall = cgrpc_channel_create_call( + self.underlyingChannel, method, host ?? self.host, timeout ?? self.timeout) + else { throw Error.callCreationFailed } + + let call = Call(underlyingCall: underlyingCall, owned: true, completionQueue: self.completionQueue) + self.activeCalls.append(WeakReference(value: call)) + return call } /// Check the current connectivity state @@ -139,4 +170,29 @@ public class Channel { observer.addConnectivityObserver(callback: callback) } } + + // MARK: - Private + + private init(host: String, underlyingChannel: UnsafeMutableRawPointer) { + self.host = host + self.underlyingChannel = underlyingChannel + self.completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), + name: "Client") + + self.completionQueue.run() + self.scheduleActiveCallCleanUp() + } + + private func scheduleActiveCallCleanUp() { + DispatchQueue.global(qos: .background).asyncAfter(deadline: .now() + 10.0) { [weak self] in + self?.cleanUpActiveCalls() + } + } + + private func cleanUpActiveCalls() { + self.mutex.synchronize { + self.activeCalls = self.activeCalls.filter { $0.value != nil } + } + self.scheduleActiveCallCleanUp() + } } diff --git a/Sources/SwiftGRPC/Runtime/ClientCall.swift b/Sources/SwiftGRPC/Runtime/ClientCall.swift index 8473c92a1..48f4725d2 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCall.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCall.swift @@ -13,27 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -import Dispatch -import Foundation import SwiftProtobuf public protocol ClientCall: class { static var method: String { get } - + /// Cancel the call. func cancel() } -open class ClientCallBase: ClientCall { +open class ClientCallBase { open class var method: String { fatalError("needs to be overridden") } public let call: Call /// Create a call. - public init(_ channel: Channel) { - call = channel.makeCall(type(of: self).method) + public init(_ channel: Channel) throws { + self.call = try channel.makeCall(type(of: self).method) + } +} + +extension ClientCallBase: ClientCall { + public func cancel() { + self.call.cancel() } - - public func cancel() { call.cancel() } } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 91d0a9807..0cc7c5c78 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -22,6 +22,7 @@ XCTMain([ testCase(gRPCTests.allTests), testCase(ChannelArgumentTests.allTests), testCase(ChannelConnectivityTests.allTests), + testCase(ChannelShutdownTests.allTests), testCase(ClientCancellingTests.allTests), testCase(ClientTestExample.allTests), testCase(ClientTimeoutTests.allTests), diff --git a/Tests/SwiftGRPCTests/ChannelShutdownTests.swift b/Tests/SwiftGRPCTests/ChannelShutdownTests.swift new file mode 100644 index 000000000..40602a069 --- /dev/null +++ b/Tests/SwiftGRPCTests/ChannelShutdownTests.swift @@ -0,0 +1,93 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@testable import SwiftGRPC +import XCTest + +final class ChannelShutdownTests: BasicEchoTestCase { + static var allTests: [(String, (ChannelShutdownTests) -> () throws -> Void)] { + return [ + ("testThrowsWhenCreatingCallWithAlreadyShutDownChannel", testThrowsWhenCreatingCallWithAlreadyShutDownChannel), + ("testCallReceiveThrowsWhenChannelIsShutDown", testCallReceiveThrowsWhenChannelIsShutDown), + ("testCallCloseThrowsWhenChannelIsShutDown", testCallCloseThrowsWhenChannelIsShutDown), + ("testCallCloseAndReceiveThrowsWhenChannelIsShutDown", testCallCloseAndReceiveThrowsWhenChannelIsShutDown), + ("testCallSendThrowsWhenChannelIsShutDown", testCallSendThrowsWhenChannelIsShutDown), + ("testCancelsActiveCallWhenShutdownIsCalled", testCancelsActiveCallWhenShutdownIsCalled), + ] + } +} + +extension ChannelShutdownTests { + func testThrowsWhenCreatingCallWithAlreadyShutDownChannel() { + self.client.channel.shutdown() + + XCTAssertThrowsError(try self.client.channel.makeCall("foobar")) { error in + XCTAssertEqual(.alreadyShutdown, error as? Channel.Error) + } + } + + func testCallReceiveThrowsWhenChannelIsShutDown() { + let call = try! self.client.channel.makeCall("foo") + self.client.channel.shutdown() + + XCTAssertThrowsError(try call.receiveMessage { _ in }) { error in + XCTAssertEqual(.completionQueueShutdown, error as? CallError) + } + } + + func testCallCloseThrowsWhenChannelIsShutDown() { + let call = try! self.client.channel.makeCall("foo") + self.client.channel.shutdown() + + XCTAssertThrowsError(try call.close()) { error in + XCTAssertEqual(.completionQueueShutdown, error as? CallError) + } + } + + func testCallCloseAndReceiveThrowsWhenChannelIsShutDown() { + let call = try! self.client.channel.makeCall("foo") + self.client.channel.shutdown() + + XCTAssertThrowsError(try call.closeAndReceiveMessage { _ in }) { error in + XCTAssertEqual(.completionQueueShutdown, error as? CallError) + } + } + + func testCallSendThrowsWhenChannelIsShutDown() { + let call = try! self.client.channel.makeCall("foo") + self.client.channel.shutdown() + + XCTAssertThrowsError(try call.sendMessage(data: Data())) { error in + XCTAssertEqual(.completionQueueShutdown, error as? CallError) + } + } + + func testCancelsActiveCallWhenShutdownIsCalled() { + let errorExpectation = self.expectation(description: "error is returned to call when channel is shut down") + let call = try! self.client.channel.makeCall("foo") + + try! call.receiveMessage { result in + XCTAssertFalse(result.success) + errorExpectation.fulfill() + } + + self.client.channel.shutdown() + self.waitForExpectations(timeout: 0.1) + + XCTAssertThrowsError(try call.close()) { error in + XCTAssertEqual(.completionQueueShutdown, error as? CallError) + } + } +} diff --git a/Tests/SwiftGRPCTests/GRPCTests.swift b/Tests/SwiftGRPCTests/GRPCTests.swift index aadd4b8df..546fb37f8 100644 --- a/Tests/SwiftGRPCTests/GRPCTests.swift +++ b/Tests/SwiftGRPCTests/GRPCTests.swift @@ -177,7 +177,7 @@ func callUnary(channel: Channel) throws { func callUnaryIndividual(channel: Channel, message: Data, shouldSucceed: Bool) throws { let sem = DispatchSemaphore(value: 0) let method = hello - let call = channel.makeCall(method) + let call = try channel.makeCall(method) let metadata = try Metadata(initialClientMetadata) try call.start(.unary, metadata: metadata, message: message) { response in @@ -228,7 +228,7 @@ func callServerStream(channel: Channel) throws { let sem = DispatchSemaphore(value: 0) let method = helloServerStream - let call = channel.makeCall(method) + let call = try channel.makeCall(method) try call.start(.serverStreaming, metadata: metadata, message: message) { response in @@ -270,7 +270,7 @@ func callBiDiStream(channel: Channel) throws { let sem = DispatchSemaphore(value: 0) let method = helloBiDiStream - let call = channel.makeCall(method) + let call = try channel.makeCall(method) try call.start(.bidiStreaming, metadata: metadata, message: nil) { response in