diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3b571a23a..438d817e6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,8 @@ jobs: matrix: include: - image: swift:5.6-focal - swift-test-flags: "--sanitize=thread" + # No TSAN because of: https://github.com/apple/swift/issues/59068 + # swift-test-flags: "--sanitize=thread" - image: swift:5.5-focal swift-test-flags: "--sanitize=thread" - image: swift:5.4-focal diff --git a/Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift b/Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift index 49d85aa7b..2159a4ad0 100644 --- a/Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift +++ b/Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift @@ -109,7 +109,7 @@ internal final actor AsyncWriter: Sendable { /// Whether the writer is paused. @usableFromInline - internal var _isPaused: Bool = false + internal var _isPaused: Bool /// The delegate to process elements. By convention we call the delegate before resuming any /// continuation. @@ -120,12 +120,14 @@ internal final actor AsyncWriter: Sendable { internal init( maxPendingElements: Int = 16, maxWritesBeforeYield: Int = 5, + isWritable: Bool = true, delegate: Delegate ) { self._maxPendingElements = maxPendingElements self._maxWritesBeforeYield = maxWritesBeforeYield self._pendingElements = CircularBuffer(initialCapacity: maxPendingElements) self._completionState = .incomplete + self._isPaused = !isWritable self._delegate = delegate } diff --git a/Sources/GRPC/AsyncAwaitSupport/Call+AsyncRequestStreamWriter.swift b/Sources/GRPC/AsyncAwaitSupport/Call+AsyncRequestStreamWriter.swift index 4b833410a..278098d2c 100644 --- a/Sources/GRPC/AsyncAwaitSupport/Call+AsyncRequestStreamWriter.swift +++ b/Sources/GRPC/AsyncAwaitSupport/Call+AsyncRequestStreamWriter.swift @@ -26,7 +26,8 @@ extension Call { self.send(.end, promise: nil) } - return GRPCAsyncRequestStreamWriter(asyncWriter: .init(delegate: delegate)) + // Start as not-writable; writability will be toggled when the stream comes up. + return GRPCAsyncRequestStreamWriter(asyncWriter: .init(isWritable: false, delegate: delegate)) } } diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift index 3a7cfd315..59cfd4ac6 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift @@ -86,6 +86,9 @@ public struct GRPCAsyncBidirectionalStreamingCall: Sendabl self.responseParts = UnaryResponseParts(on: call.eventLoop) self.call.invokeUnaryRequest( request, + onStart: {}, onError: self.responseParts.handleError(_:), onResponsePart: self.responseParts.handle(_:) ) diff --git a/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift b/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift index a1dd58c39..e14c8939a 100644 --- a/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift +++ b/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift @@ -84,6 +84,7 @@ public struct BidirectionalStreamingCall< internal func invoke() { self.call.invokeStreamingRequests( + onStart: {}, onError: self.responseParts.handleError(_:), onResponsePart: self.responseParts.handle(_:) ) diff --git a/Sources/GRPC/ClientCalls/Call.swift b/Sources/GRPC/ClientCalls/Call.swift index 9244a1687..641622fbb 100644 --- a/Sources/GRPC/ClientCalls/Call.swift +++ b/Sources/GRPC/ClientCalls/Call.swift @@ -123,10 +123,10 @@ public final class Call { self.options.logger.debug("starting rpc", metadata: ["path": "\(self.path)"], source: "GRPC") if self.eventLoop.inEventLoop { - self._invoke(onError: onError, onResponsePart: onResponsePart) + self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart) } else { self.eventLoop.execute { - self._invoke(onError: onError, onResponsePart: onResponsePart) + self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart) } } } @@ -262,6 +262,7 @@ extension Call { /// - Important: This *must* to be called from the `eventLoop`. @usableFromInline internal func _invoke( + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { @@ -275,6 +276,7 @@ extension Call { withOptions: self.options, onEventLoop: self.eventLoop, interceptedBy: self._interceptors, + onStart: onStart, onError: onError, onResponsePart: onResponsePart ) @@ -354,14 +356,25 @@ extension Call { @inlinable internal func invokeUnaryRequest( _ request: Request, + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { if self.eventLoop.inEventLoop { - self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart) + self._invokeUnaryRequest( + request: request, + onStart: onStart, + onError: onError, + onResponsePart: onResponsePart + ) } else { self.eventLoop.execute { - self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart) + self._invokeUnaryRequest( + request: request, + onStart: onStart, + onError: onError, + onResponsePart: onResponsePart + ) } } } @@ -373,14 +386,23 @@ extension Call { /// - onResponsePart: A callback invoked for each response part received. @inlinable internal func invokeStreamingRequests( + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { if self.eventLoop.inEventLoop { - self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart) + self._invokeStreamingRequests( + onStart: onStart, + onError: onError, + onResponsePart: onResponsePart + ) } else { self.eventLoop.execute { - self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart) + self._invokeStreamingRequests( + onStart: onStart, + onError: onError, + onResponsePart: onResponsePart + ) } } } @@ -389,13 +411,14 @@ extension Call { @usableFromInline internal func _invokeUnaryRequest( request: Request, + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { self.eventLoop.assertInEventLoop() assert(self.type == .unary || self.type == .serverStreaming) - self._invoke(onError: onError, onResponsePart: onResponsePart) + self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart) self._send(.metadata(self.options.customMetadata), promise: nil) self._send( .message(request, .init(compress: self.isCompressionEnabled, flush: false)), @@ -407,13 +430,14 @@ extension Call { /// On-`EventLoop` implementation of `invokeStreamingRequests(_:)`. @usableFromInline internal func _invokeStreamingRequests( + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { self.eventLoop.assertInEventLoop() assert(self.type == .clientStreaming || self.type == .bidirectionalStreaming) - self._invoke(onError: onError, onResponsePart: onResponsePart) + self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart) self._send(.metadata(self.options.customMetadata), promise: nil) } } diff --git a/Sources/GRPC/ClientCalls/ClientStreamingCall.swift b/Sources/GRPC/ClientCalls/ClientStreamingCall.swift index 52362b04f..03bd534ca 100644 --- a/Sources/GRPC/ClientCalls/ClientStreamingCall.swift +++ b/Sources/GRPC/ClientCalls/ClientStreamingCall.swift @@ -84,6 +84,7 @@ public struct ClientStreamingCall: StreamingReq internal func invoke() { self.call.invokeStreamingRequests( + onStart: {}, onError: self.responseParts.handleError(_:), onResponsePart: self.responseParts.handle(_:) ) diff --git a/Sources/GRPC/ClientCalls/ServerStreamingCall.swift b/Sources/GRPC/ClientCalls/ServerStreamingCall.swift index 26799cdff..67b93226e 100644 --- a/Sources/GRPC/ClientCalls/ServerStreamingCall.swift +++ b/Sources/GRPC/ClientCalls/ServerStreamingCall.swift @@ -80,6 +80,7 @@ public struct ServerStreamingCall: ClientCall { internal func invoke(_ request: RequestPayload) { self.call.invokeUnaryRequest( request, + onStart: {}, onError: self.responseParts.handleError(_:), onResponsePart: self.responseParts.handle(_:) ) diff --git a/Sources/GRPC/ClientCalls/UnaryCall.swift b/Sources/GRPC/ClientCalls/UnaryCall.swift index fca720e6d..5cb7a712a 100644 --- a/Sources/GRPC/ClientCalls/UnaryCall.swift +++ b/Sources/GRPC/ClientCalls/UnaryCall.swift @@ -84,6 +84,7 @@ public struct UnaryCall: UnaryResponseClientCal internal func invoke(_ request: RequestPayload) { self.call.invokeUnaryRequest( request, + onStart: {}, onError: self.responseParts.handleError(_:), onResponsePart: self.responseParts.handle(_:) ) diff --git a/Sources/GRPC/Interceptor/ClientTransport.swift b/Sources/GRPC/Interceptor/ClientTransport.swift index 9562b9261..49b7dbfb1 100644 --- a/Sources/GRPC/Interceptor/ClientTransport.swift +++ b/Sources/GRPC/Interceptor/ClientTransport.swift @@ -93,6 +93,9 @@ internal final class ClientTransport { /// The `NIO.Channel` used by the transport, if it is available. private var channel: Channel? + /// A callback which is invoked once when the stream channel becomes active. + private let onStart: () -> Void + /// Our current state as logging metadata. private var stateForLogging: Logger.MetadataValue { if self.state.mayBuffer { @@ -109,11 +112,13 @@ internal final class ClientTransport { serializer: AnySerializer, deserializer: AnyDeserializer, errorDelegate: ClientErrorDelegate?, + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) { self.callEventLoop = eventLoop self.callDetails = details + self.onStart = onStart let logger = GRPCLogger(wrapping: details.options.logger) self.logger = logger self.serializer = serializer @@ -332,6 +337,7 @@ extension ClientTransport { self._pipeline?.logger = self.logger self.logger.debug("activated stream channel") self.channel = channel + self.onStart() self.unbuffer() case .close: diff --git a/Sources/GRPC/Interceptor/ClientTransportFactory.swift b/Sources/GRPC/Interceptor/ClientTransportFactory.swift index f5e597b76..2010fbe9a 100644 --- a/Sources/GRPC/Interceptor/ClientTransportFactory.swift +++ b/Sources/GRPC/Interceptor/ClientTransportFactory.swift @@ -140,6 +140,7 @@ internal struct ClientTransportFactory { withOptions options: CallOptions, onEventLoop eventLoop: EventLoop, interceptedBy interceptors: [ClientInterceptor], + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) -> ClientTransport { @@ -151,6 +152,7 @@ internal struct ClientTransportFactory { withOptions: options, onEventLoop: eventLoop, interceptedBy: interceptors, + onStart: onStart, onError: onError, onResponsePart: onResponsePart ) @@ -220,6 +222,7 @@ internal struct HTTP2ClientTransportFactory { withOptions options: CallOptions, onEventLoop eventLoop: EventLoop, interceptedBy interceptors: [ClientInterceptor], + onStart: @escaping () -> Void, onError: @escaping (Error) -> Void, onResponsePart: @escaping (GRPCClientResponsePart) -> Void ) -> ClientTransport { @@ -230,6 +233,7 @@ internal struct HTTP2ClientTransportFactory { serializer: self.serializer, deserializer: self.deserializer, errorDelegate: self.errorDelegate, + onStart: onStart, onError: onError, onResponsePart: onResponsePart ) @@ -333,6 +337,7 @@ internal struct FakeClientTransportFactory { serializer: self.requestSerializer, deserializer: self.responseDeserializer, errorDelegate: nil, + onStart: {}, onError: onError, onResponsePart: onResponsePart ) diff --git a/Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift b/Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift index 4360ccd75..6362c1dbc 100644 --- a/Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift +++ b/Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift @@ -27,17 +27,21 @@ final class AsyncClientCancellationTests: GRPCTestCase { private var group: EventLoopGroup! private var pool: GRPCChannel! - override func setUpWithError() throws { - try super.setUpWithError() + override func setUp() { + super.setUp() self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) } override func tearDown() async throws { - try self.pool.close().wait() - self.pool = nil + if self.pool != nil { + try self.pool.close().wait() + self.pool = nil + } - try self.server.close().wait() - self.server = nil + if self.server != nil { + try self.server.close().wait() + self.server = nil + } try self.group.syncShutdownGracefully() self.group = nil @@ -45,18 +49,26 @@ final class AsyncClientCancellationTests: GRPCTestCase { try await super.tearDown() } - private func startServer(service: CallHandlerProvider) throws -> Echo_EchoAsyncClient { + private func startServer(service: CallHandlerProvider) throws { precondition(self.server == nil) - precondition(self.pool == nil) self.server = try Server.insecure(group: self.group) .withServiceProviders([service]) .withLogger(self.serverLogger) .bind(host: "127.0.0.1", port: 0) .wait() + } + + private func startServerAndClient(service: CallHandlerProvider) throws -> Echo_EchoAsyncClient { + try self.startServer(service: service) + return try self.makeClient(port: self.server.channel.localAddress!.port!) + } + + private func makeClient(port: Int) throws -> Echo_EchoAsyncClient { + precondition(self.pool == nil) self.pool = try GRPCChannelPool.with( - target: .host("127.0.0.1", port: self.server.channel.localAddress!.port!), + target: .host("127.0.0.1", port: port), transportSecurity: .plaintext, eventLoopGroup: self.group ) { @@ -68,7 +80,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { func testCancelUnaryFailsResponse() async throws { // We don't want the RPC to complete before we cancel it so use the never resolving service. - let echo = try self.startServer(service: NeverResolvingEchoProvider()) + let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider()) let get = echo.makeGetCall(.with { $0.text = "foo bar baz" }) try await get.cancel() @@ -82,7 +94,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { func testCancelServerStreamingClosesResponseStream() async throws { // We don't want the RPC to complete before we cancel it so use the never resolving service. - let echo = try self.startServer(service: NeverResolvingEchoProvider()) + let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider()) let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" }) try await expand.cancel() @@ -96,7 +108,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { } func testCancelClientStreamingClosesRequestStreamAndFailsResponse() async throws { - let echo = try self.startServer(service: EchoProvider()) + let echo = try self.startServerAndClient(service: EchoProvider()) let collect = echo.makeCollectCall() // Make sure the stream is up before we cancel it. @@ -114,7 +126,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { } func testClientStreamingClosesRequestStreamOnEnd() async throws { - let echo = try self.startServer(service: EchoProvider()) + let echo = try self.startServerAndClient(service: EchoProvider()) let collect = echo.makeCollectCall() // Send and close. @@ -133,7 +145,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { } func testCancelBidiStreamingClosesRequestStreamAndResponseStream() async throws { - let echo = try self.startServer(service: EchoProvider()) + let echo = try self.startServerAndClient(service: EchoProvider()) let update = echo.makeUpdateCall() // Make sure the stream is up before we cancel it. @@ -153,7 +165,7 @@ final class AsyncClientCancellationTests: GRPCTestCase { } func testBidiStreamingClosesRequestStreamOnEnd() async throws { - let echo = try self.startServer(service: EchoProvider()) + let echo = try self.startServerAndClient(service: EchoProvider()) let update = echo.makeUpdateCall() // Send and close. @@ -172,6 +184,98 @@ final class AsyncClientCancellationTests: GRPCTestCase { try await update.requestStream.send(.with { $0.text = "should throw" }) ) } + + private enum RequestStreamingRPC { + typealias Request = Echo_EchoRequest + typealias Response = Echo_EchoResponse + + case clientStreaming(GRPCAsyncClientStreamingCall) + case bidirectionalStreaming(GRPCAsyncBidirectionalStreamingCall) + + func sendRequest(_ text: String) async throws { + switch self { + case let .clientStreaming(call): + try await call.requestStream.send(.with { $0.text = text }) + case let .bidirectionalStreaming(call): + try await call.requestStream.send(.with { $0.text = text }) + } + } + + func cancel() { + switch self { + case let .clientStreaming(call): + // TODO: this should be async + Task { try await call.cancel() } + case let .bidirectionalStreaming(call): + // TODO: this should be async + Task { try await call.cancel() } + } + } + } + + private func testSendingRequestsSuspendsWhileStreamIsNotReady( + makeRPC: @escaping () -> RequestStreamingRPC + ) async throws { + // The strategy for this test is to race two different tasks. The first will attempt to send a + // message on a request stream on a connection which will never establish. The second will sleep + // for a little while. Each task returns a `SendOrTimedOut` event. If the message is sent then + // the test definitely failed; it should not be possible to send a message on a stream which is + // not open. If the time out happens first then it probably did not fail. + enum SentOrTimedOut: Equatable, Sendable { + case messageSent + case timedOut + } + + await withThrowingTaskGroup(of: SentOrTimedOut.self) { group in + group.addTask { + let rpc = makeRPC() + + return try await withTaskCancellationHandler { + // This should suspend until we cancel it: we're never going to start a server so it + // should never succeed. + try await rpc.sendRequest("I should suspend") + return .messageSent + } onCancel: { + rpc.cancel() + } + } + + group.addTask { + // Wait for 100ms. + try await Task.sleep(nanoseconds: 100_000_000) + return .timedOut + } + + do { + let event = try await group.next() + // If this isn't timed out then the message was sent before the stream was ready. + XCTAssertEqual(event, .timedOut) + } catch { + XCTFail("Unexpected error \(error)") + } + + // Cancel the other task. + group.cancelAll() + } + } + + func testClientStreamingSuspendsWritesUntilStreamIsUp() async throws { + // Make a client for a server which isn't up yet. It will continually fail to establish a + // connection. + let echo = try self.makeClient(port: 0) + try await self.testSendingRequestsSuspendsWhileStreamIsNotReady { + return .clientStreaming(echo.makeCollectCall()) + } + } + + func testBidirectionalStreamingSuspendsWritesUntilStreamIsUp() async throws { + // Make a client for a server which isn't up yet. It will continually fail to establish a + // connection. + let echo = try self.makeClient(port: 0) + try await self.testSendingRequestsSuspendsWhileStreamIsNotReady { + return .bidirectionalStreaming(echo.makeUpdateCall()) + } + } } #endif // compiler(>=5.6) diff --git a/Tests/GRPCTests/ClientCallTests.swift b/Tests/GRPCTests/ClientCallTests.swift index fce9d3f15..219d60005 100644 --- a/Tests/GRPCTests/ClientCallTests.swift +++ b/Tests/GRPCTests/ClientCallTests.swift @@ -122,6 +122,7 @@ class ClientCallTests: GRPCTestCase { let promise = self.makeStatusPromise() get.invokeUnaryRequest( .with { $0.text = "get" }, + onStart: {}, onError: promise.fail(_:), onResponsePart: self.makeResponsePartHandler(completing: promise) ) @@ -134,6 +135,7 @@ class ClientCallTests: GRPCTestCase { let promise = self.makeStatusPromise() collect.invokeStreamingRequests( + onStart: {}, onError: promise.fail(_:), onResponsePart: self.makeResponsePartHandler(completing: promise) ) @@ -152,6 +154,7 @@ class ClientCallTests: GRPCTestCase { let promise = self.makeStatusPromise() expand.invokeUnaryRequest( .with { $0.text = "expand" }, + onStart: {}, onError: promise.fail(_:), onResponsePart: self.makeResponsePartHandler(completing: promise) ) @@ -164,6 +167,7 @@ class ClientCallTests: GRPCTestCase { let promise = self.makeStatusPromise() update.invokeStreamingRequests( + onStart: {}, onError: promise.fail(_:), onResponsePart: self.makeResponsePartHandler(completing: promise) ) diff --git a/Tests/GRPCTests/ClientTransportTests.swift b/Tests/GRPCTests/ClientTransportTests.swift index ee40ba153..856e69c30 100644 --- a/Tests/GRPCTests/ClientTransportTests.swift +++ b/Tests/GRPCTests/ClientTransportTests.swift @@ -56,6 +56,7 @@ class ClientTransportTests: GRPCTestCase { serializer: AnySerializer(wrapping: StringSerializer()), deserializer: AnyDeserializer(wrapping: StringDeserializer()), errorDelegate: nil, + onStart: {}, onError: onError, onResponsePart: onResponsePart )