Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suspend request stream writes before the RPC is ready #1411

Merged
merged 2 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ jobs:
matrix:
include:
- image: swift:5.6-focal
swift-test-flags: "--sanitize=thread"
# No TSAN because of: https://github.com/apple/swift/issues/59068
# swift-test-flags: "--sanitize=thread"
- image: swift:5.5-focal
swift-test-flags: "--sanitize=thread"
- image: swift:5.4-focal
Expand Down
4 changes: 3 additions & 1 deletion Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {

/// Whether the writer is paused.
@usableFromInline
internal var _isPaused: Bool = false
internal var _isPaused: Bool

/// The delegate to process elements. By convention we call the delegate before resuming any
/// continuation.
Expand All @@ -120,12 +120,14 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
internal init(
maxPendingElements: Int = 16,
maxWritesBeforeYield: Int = 5,
isWritable: Bool = true,
delegate: Delegate
) {
self._maxPendingElements = maxPendingElements
self._maxWritesBeforeYield = maxWritesBeforeYield
self._pendingElements = CircularBuffer(initialCapacity: maxPendingElements)
self._completionState = .incomplete
self._isPaused = !isWritable
self._delegate = delegate
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ extension Call {
self.send(.end, promise: nil)
}

return GRPCAsyncRequestStreamWriter(asyncWriter: .init(delegate: delegate))
// Start as not-writable; writability will be toggled when the stream comes up.
return GRPCAsyncRequestStreamWriter(asyncWriter: .init(isWritable: false, delegate: delegate))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
let asyncCall = Self(call: call)

asyncCall.call.invokeStreamingRequests(
onStart: {
asyncCall.requestStream.asyncWriter.toggleWritabilityAsynchronously()
},
onError: { error in
asyncCall.responseParts.handleError(error)
asyncCall.responseSource.finish(throwing: error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
let asyncCall = Self(call: call)

asyncCall.call.invokeStreamingRequests(
onStart: {
asyncCall.requestStream.asyncWriter.toggleWritabilityAsynchronously()
},
onError: { error in
asyncCall.responseParts.handleError(error)
asyncCall.requestStream.asyncWriter.cancelAsynchronously()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable

asyncCall.call.invokeUnaryRequest(
request,
onStart: {},
onError: { error in
asyncCall.responseParts.handleError(error)
asyncCall.responseSource.finish(throwing: error)
Expand Down
1 change: 1 addition & 0 deletions Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
self.responseParts = UnaryResponseParts(on: call.eventLoop)
self.call.invokeUnaryRequest(
request,
onStart: {},
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
Expand Down
1 change: 1 addition & 0 deletions Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public struct BidirectionalStreamingCall<

internal func invoke() {
self.call.invokeStreamingRequests(
onStart: {},
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
Expand Down
40 changes: 32 additions & 8 deletions Sources/GRPC/ClientCalls/Call.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ public final class Call<Request, Response> {
self.options.logger.debug("starting rpc", metadata: ["path": "\(self.path)"], source: "GRPC")

if self.eventLoop.inEventLoop {
self._invoke(onError: onError, onResponsePart: onResponsePart)
self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart)
} else {
self.eventLoop.execute {
self._invoke(onError: onError, onResponsePart: onResponsePart)
self._invoke(onStart: {}, onError: onError, onResponsePart: onResponsePart)
}
}
}
Expand Down Expand Up @@ -262,6 +262,7 @@ extension Call {
/// - Important: This *must* to be called from the `eventLoop`.
@usableFromInline
internal func _invoke(
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
Expand All @@ -275,6 +276,7 @@ extension Call {
withOptions: self.options,
onEventLoop: self.eventLoop,
interceptedBy: self._interceptors,
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
Expand Down Expand Up @@ -354,14 +356,25 @@ extension Call {
@inlinable
internal func invokeUnaryRequest(
_ request: Request,
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
if self.eventLoop.inEventLoop {
self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart)
self._invokeUnaryRequest(
request: request,
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
} else {
self.eventLoop.execute {
self._invokeUnaryRequest(request: request, onError: onError, onResponsePart: onResponsePart)
self._invokeUnaryRequest(
request: request,
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
}
}
}
Expand All @@ -373,14 +386,23 @@ extension Call {
/// - onResponsePart: A callback invoked for each response part received.
@inlinable
internal func invokeStreamingRequests(
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
if self.eventLoop.inEventLoop {
self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart)
self._invokeStreamingRequests(
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
} else {
self.eventLoop.execute {
self._invokeStreamingRequests(onError: onError, onResponsePart: onResponsePart)
self._invokeStreamingRequests(
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
}
}
}
Expand All @@ -389,13 +411,14 @@ extension Call {
@usableFromInline
internal func _invokeUnaryRequest(
request: Request,
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
self.eventLoop.assertInEventLoop()
assert(self.type == .unary || self.type == .serverStreaming)

self._invoke(onError: onError, onResponsePart: onResponsePart)
self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart)
self._send(.metadata(self.options.customMetadata), promise: nil)
self._send(
.message(request, .init(compress: self.isCompressionEnabled, flush: false)),
Expand All @@ -407,13 +430,14 @@ extension Call {
/// On-`EventLoop` implementation of `invokeStreamingRequests(_:)`.
@usableFromInline
internal func _invokeStreamingRequests(
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
self.eventLoop.assertInEventLoop()
assert(self.type == .clientStreaming || self.type == .bidirectionalStreaming)

self._invoke(onError: onError, onResponsePart: onResponsePart)
self._invoke(onStart: onStart, onError: onError, onResponsePart: onResponsePart)
self._send(.metadata(self.options.customMetadata), promise: nil)
}
}
Expand Down
1 change: 1 addition & 0 deletions Sources/GRPC/ClientCalls/ClientStreamingCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public struct ClientStreamingCall<RequestPayload, ResponsePayload>: StreamingReq

internal func invoke() {
self.call.invokeStreamingRequests(
onStart: {},
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
Expand Down
1 change: 1 addition & 0 deletions Sources/GRPC/ClientCalls/ServerStreamingCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public struct ServerStreamingCall<RequestPayload, ResponsePayload>: ClientCall {
internal func invoke(_ request: RequestPayload) {
self.call.invokeUnaryRequest(
request,
onStart: {},
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
Expand Down
1 change: 1 addition & 0 deletions Sources/GRPC/ClientCalls/UnaryCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public struct UnaryCall<RequestPayload, ResponsePayload>: UnaryResponseClientCal
internal func invoke(_ request: RequestPayload) {
self.call.invokeUnaryRequest(
request,
onStart: {},
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
Expand Down
6 changes: 6 additions & 0 deletions Sources/GRPC/Interceptor/ClientTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ internal final class ClientTransport<Request, Response> {
/// The `NIO.Channel` used by the transport, if it is available.
private var channel: Channel?

/// A callback which is invoked once when the stream channel becomes active.
private let onStart: () -> Void

/// Our current state as logging metadata.
private var stateForLogging: Logger.MetadataValue {
if self.state.mayBuffer {
Expand All @@ -109,11 +112,13 @@ internal final class ClientTransport<Request, Response> {
serializer: AnySerializer<Request>,
deserializer: AnyDeserializer<Response>,
errorDelegate: ClientErrorDelegate?,
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
self.callEventLoop = eventLoop
self.callDetails = details
self.onStart = onStart
let logger = GRPCLogger(wrapping: details.options.logger)
self.logger = logger
self.serializer = serializer
Expand Down Expand Up @@ -332,6 +337,7 @@ extension ClientTransport {
self._pipeline?.logger = self.logger
self.logger.debug("activated stream channel")
self.channel = channel
self.onStart()
self.unbuffer()

case .close:
Expand Down
5 changes: 5 additions & 0 deletions Sources/GRPC/Interceptor/ClientTransportFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ internal struct ClientTransportFactory<Request, Response> {
withOptions options: CallOptions,
onEventLoop eventLoop: EventLoop,
interceptedBy interceptors: [ClientInterceptor<Request, Response>],
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) -> ClientTransport<Request, Response> {
Expand All @@ -151,6 +152,7 @@ internal struct ClientTransportFactory<Request, Response> {
withOptions: options,
onEventLoop: eventLoop,
interceptedBy: interceptors,
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
Expand Down Expand Up @@ -220,6 +222,7 @@ internal struct HTTP2ClientTransportFactory<Request, Response> {
withOptions options: CallOptions,
onEventLoop eventLoop: EventLoop,
interceptedBy interceptors: [ClientInterceptor<Request, Response>],
onStart: @escaping () -> Void,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) -> ClientTransport<Request, Response> {
Expand All @@ -230,6 +233,7 @@ internal struct HTTP2ClientTransportFactory<Request, Response> {
serializer: self.serializer,
deserializer: self.deserializer,
errorDelegate: self.errorDelegate,
onStart: onStart,
onError: onError,
onResponsePart: onResponsePart
)
Expand Down Expand Up @@ -333,6 +337,7 @@ internal struct FakeClientTransportFactory<Request, Response> {
serializer: self.requestSerializer,
deserializer: self.responseDeserializer,
errorDelegate: nil,
onStart: {},
onError: onError,
onResponsePart: onResponsePart
)
Expand Down
Loading