diff --git a/Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift b/Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift index 24d70f4c..0aa8a3fa 100644 --- a/Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift +++ b/Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift @@ -489,10 +489,12 @@ extension NIOHTTP2AsyncSequence { switch yieldResult { case .enqueued: break // success, nothing to do + case .terminated: + // this can happen if the task has been cancelled + // we can't do better than dropping the message at the moment + break case .dropped: preconditionFailure("Attempted to yield when AsyncThrowingStream is over capacity. This shouldn't be possible for an unbounded stream.") - case .terminated: - preconditionFailure("Attempted to yield to AsyncThrowingStream in terminated state.") default: preconditionFailure("Attempt to yield to AsyncThrowingStream failed for unhandled reason.") } diff --git a/Tests/NIOHTTP2Tests/ConfiguringPipelineAsyncMultiplexerTests.swift b/Tests/NIOHTTP2Tests/ConfiguringPipelineAsyncMultiplexerTests.swift index f3208943..660b95de 100644 --- a/Tests/NIOHTTP2Tests/ConfiguringPipelineAsyncMultiplexerTests.swift +++ b/Tests/NIOHTTP2Tests/ConfiguringPipelineAsyncMultiplexerTests.swift @@ -92,6 +92,70 @@ final class ConfiguringPipelineAsyncMultiplexerTests: XCTestCase { } } + func testCancellingAsyncStreamConsumer() async throws { + let requestCount = 200 + + let serverRecorder = InboundFramePayloadRecorder() + + let clientMultiplexer = try await assertNoThrowWithValue( + try await self.clientChannel.configureAsyncHTTP2Pipeline(mode: .client) { channel -> EventLoopFuture in + channel.eventLoop.makeSucceededFuture(channel) + }.get() + ) + + let serverMultiplexer = try await assertNoThrowWithValue( + try await self.serverChannel.configureAsyncHTTP2Pipeline(mode: .server) { channel -> EventLoopFuture in + channel.pipeline.addHandlers([OKResponder(), serverRecorder]).map { _ in channel } + }.get() + ) + + try await assertNoThrow(try await self.assertDoHandshake(client: self.clientChannel, server: self.serverChannel)) + + // Launch a server + let serverTask = Task { + var serverInboundChannelCount = 0 + for try await _ in serverMultiplexer.inbound { + serverInboundChannelCount += 1 + } + + try Task.checkCancellation() + + return serverInboundChannelCount + } + + // client + for i in 0 ..< requestCount { + // Let's try sending some requests. + let streamChannel = try await clientMultiplexer.openStream { channel -> EventLoopFuture in + return channel.pipeline.addHandlers([SimpleRequest(), InboundFramePayloadRecorder()]).map { + return channel + } + } + + // When we get above 100, cancel the server task. + if i == 100 { serverTask.cancel() } + + let clientRecorder = try await streamChannel.pipeline.handler(type: InboundFramePayloadRecorder.self).get() + try await Self.deliverAllBytes(from: self.clientChannel, to: self.serverChannel) + try await Self.deliverAllBytes(from: self.serverChannel, to: self.clientChannel) + + clientRecorder.receivedFrames.assertFramePayloadsMatch([ConfiguringPipelineAsyncMultiplexerTests.responseFramePayload]) + try await streamChannel.closeFuture.get() + } + + try await assertNoThrow(try await self.clientChannel.finish()) + try await assertNoThrow(try await self.serverChannel.finish()) + + do { + _ = try await serverTask.value + XCTFail("Server unexpectedly succeeded") + } catch is CancellationError { + // Expected + } catch { + XCTFail("Unexpected error throw: \(error)") + } + } + // `testBasicPipelineCommunicates` ensures that a client-server system set up to use async stream abstractions // can communicate successfully. func testBasicPipelineCommunicates() async throws {