diff --git a/Sources/NIOExtras/NIORequestIdentifiable.swift b/Sources/NIOExtras/NIORequestIdentifiable.swift new file mode 100644 index 00000000..6b5ba8c6 --- /dev/null +++ b/Sources/NIOExtras/NIORequestIdentifiable.swift @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public protocol NIORequestIdentifiable { + associatedtype RequestID: Hashable + + var requestID: RequestID { get } +} diff --git a/Sources/NIOExtras/TaggedRequestResponseHandler.swift b/Sources/NIOExtras/TaggedRequestResponseHandler.swift new file mode 100644 index 00000000..db0e675b --- /dev/null +++ b/Sources/NIOExtras/TaggedRequestResponseHandler.swift @@ -0,0 +1,130 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +/// `NIOTaggedRequestResponseHandler` receives a `Request` alongside an `EventLoopPromise` from the +/// `Channel`'s outbound side. It will fulfill the promise with the `Response` once it's received from the `Channel`'s +/// inbound side. Requests and responses can arrive out-of-order and are matched by the virtue of being +/// `NIORequestIdentifiable`. +/// +/// `NIOTaggedRequestResponseHandler` does support pipelining `Request`s and it will send them pipelined further down the +/// `Channel`. Should `RequestResponseHandler` receive an error from the `Channel`, it will fail all promises meant for +/// the outstanding `Reponse`s and close the `Channel`. All requests enqueued after an error occured will be immediately +/// failed with the first error the channel received. +/// +/// `NIOTaggedRequestResponseHandler` does _not_ require that the `Response`s arrive on `Channel` in the same order as +/// the `Request`s were submitted. They are matched by their `requestID` property (from `NIORequestIdentifiable`). +public final class NIOTaggedRequestResponseHandler: ChannelDuplexHandler + where Request.RequestID == Response.RequestID { + public typealias InboundIn = Response + public typealias InboundOut = Never + public typealias OutboundIn = (Request, EventLoopPromise) + public typealias OutboundOut = Request + + private enum State { + case operational + case error(Error) + + var isOperational: Bool { + switch self { + case .operational: + return true + case .error: + return false + } + } + } + + private var state: State = .operational + private var promiseBuffer: [Request.RequestID: EventLoopPromise] + + + /// Create a new `RequestResponseHandler`. + /// + /// - parameters: + /// - initialBufferCapacity: `RequestResponseHandler` saves the promises for all outstanding responses in a + /// buffer. `initialBufferCapacity` is the initial capacity for this buffer. You usually do not need to set + /// this parameter unless you intend to pipeline very deeply and don't want the buffer to resize. + public init(initialBufferCapacity: Int = 4) { + self.promiseBuffer = [:] + self.promiseBuffer.reserveCapacity(initialBufferCapacity) + } + + public func channelInactive(context: ChannelHandlerContext) { + switch self.state { + case .error: + // We failed any outstanding promises when we entered the error state and will fail any + // new promises in write. + assert(self.promiseBuffer.count == 0) + case .operational: + let promiseBuffer = self.promiseBuffer + self.promiseBuffer.removeAll() + promiseBuffer.forEach { promise in + promise.value.fail(NIOExtrasErrors.ClosedBeforeReceivingResponse()) + } + } + context.fireChannelInactive() + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard self.state.isOperational else { + // we're in an error state, ignore further responses + assert(self.promiseBuffer.count == 0) + return + } + + let response = self.unwrapInboundIn(data) + if let promise = self.promiseBuffer.removeValue(forKey: response.requestID) { + promise.succeed(response) + } else { + context.fireErrorCaught(NIOExtrasErrors.ResponseForInvalidRequest(id: response.requestID)) + } + } + + public func errorCaught(context: ChannelHandlerContext, error: Error) { + guard self.state.isOperational else { + assert(self.promiseBuffer.count == 0) + return + } + self.state = .error(error) + let promiseBuffer = self.promiseBuffer + self.promiseBuffer.removeAll() + context.close(promise: nil) + promiseBuffer.forEach { + $0.value.fail(error) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let (request, responsePromise) = self.unwrapOutboundIn(data) + switch self.state { + case .error(let error): + assert(self.promiseBuffer.count == 0) + responsePromise.fail(error) + promise?.fail(error) + case .operational: + self.promiseBuffer[request.requestID] = responsePromise + context.write(self.wrapOutboundOut(request), promise: promise) + } + } +} + +extension NIOExtrasErrors { + public struct ResponseForInvalidRequest: NIOExtrasError, Equatable { + var id: Response.RequestID + } +} +