From 302209f0c4b3e3fd0902dab10e357bc30807322b Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Wed, 12 Aug 2020 15:00:44 +0100 Subject: [PATCH] Avoid unnecessary arrays. Motivation: HTTP1ToGRPCServerCodec currently creates a temporary array for parsing all messages into before it forwards them on. This is both a minor perf drain (due to the extra allocations) and a correctness problem, as it makes this channel handler non-reentrant-safe. We should fix both issues. Modifications: - Replace the temporary array with a simple loop. - Add tests that validates correct behaviour on reentrancy. Result: Better re-entrancy behaviour! Verrrry slightly better perf. --- Sources/GRPC/HTTP1ToGRPCServerCodec.swift | 16 +- .../HTTP1ToGRPCServerCodecTests.swift | 142 ++++++++++++++++++ Tests/GRPCTests/XCTestManifests.swift | 2 + 3 files changed, 152 insertions(+), 8 deletions(-) diff --git a/Sources/GRPC/HTTP1ToGRPCServerCodec.swift b/Sources/GRPC/HTTP1ToGRPCServerCodec.swift index 81cf58788..d7a9b6098 100644 --- a/Sources/GRPC/HTTP1ToGRPCServerCodec.swift +++ b/Sources/GRPC/HTTP1ToGRPCServerCodec.swift @@ -270,10 +270,12 @@ extension HTTP1ToGRPCServerCodec: ChannelInboundHandler { } self.messageReader.append(buffer: &body) - var requests: [ByteBuffer] = [] do { - while let buffer = try self.messageReader.nextMessage() { - requests.append(buffer) + // We may be re-entrantly called, and that re-entrant call may error. If the state changed for any reason, + // stop looping. + while self.inboundState == .expectingBody, + let buffer = try self.messageReader.nextMessage() { + context.fireChannelRead(self.wrapInboundOut(.message(buffer))) } } catch let grpcError as GRPCError.WithContext { context.fireErrorCaught(grpcError) @@ -283,11 +285,9 @@ extension HTTP1ToGRPCServerCodec: ChannelInboundHandler { return .ignore } - requests.forEach { - context.fireChannelRead(self.wrapInboundOut(.message($0))) - } - - return .expectingBody + // We may have been called re-entrantly and transitioned out of the state we were in (e.g. because of an + // error). In all cases, if we get here we want to persist the current state. + return self.inboundState } private func processEnd(context: ChannelHandlerContext, diff --git a/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift b/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift index 6760f47a0..f5fd6d2b0 100644 --- a/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift +++ b/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift @@ -22,6 +22,39 @@ import NIO import NIOHTTP1 import XCTest +/// A trivial channel handler that invokes a callback once, the first time it sees +/// channelRead. +final class OnFirstReadHandler: ChannelInboundHandler { + typealias InboundIn = Any + typealias InboundOut = Any + + private var callback: (() -> Void)? + + init(callback: @escaping () -> Void) { + self.callback = callback + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.fireChannelRead(data) + + if let callback = self.callback { + self.callback = nil + callback() + } + } +} + +final class ErrorRecordingHandler: ChannelInboundHandler { + typealias InboundIn = Any + + var errors: [Error] = [] + + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.errors.append(error) + context.fireErrorCaught(error) + } +} + class HTTP1ToGRPCServerCodecTests: GRPCTestCase { var channel: EmbeddedChannel! @@ -127,4 +160,113 @@ class HTTP1ToGRPCServerCodecTests: GRPCTestCase { } } } + + func testReentrantMessageDelivery() throws { + XCTAssertNoThrow( + try self.channel + .writeInbound(HTTPServerRequestPart.head(self.makeRequestHead())) + ) + let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self) + + switch requestPart { + case .some(.head): + () + default: + XCTFail("Unexpected request part: \(String(describing: requestPart))") + } + + // Write three messages into a single body. + var buffer = self.channel.allocator.buffer(capacity: 0) + let serializedMessages: [Data] = try ["foo", "bar", "baz"].map { text in + Echo_EchoRequest.with { $0.text = text } + }.map { request in + try request.serializedData() + } + + for data in serializedMessages { + buffer.writeInteger(UInt8(0)) + buffer.writeInteger(UInt32(data.count)) + buffer.writeBytes(data) + } + + // Create an OnFirstReadHandler that will _also_ send the data when it sees the first read. + // This is try! because it cannot throw. + let onFirstRead = OnFirstReadHandler { + try! self.channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + XCTAssertNoThrow(try self.channel.pipeline.addHandler(onFirstRead).wait()) + + // Now write. + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(buffer))) + + // This must not re-order messages. + for message in [serializedMessages, serializedMessages].flatMap({ $0 }) { + let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self) + switch requestPart { + case var .some(.message(buffer)): + XCTAssertEqual(message, buffer.readData(length: buffer.readableBytes)!) + default: + XCTFail("Unexpected request part: \(String(describing: requestPart))") + } + } + } + + func testErrorsOnlyHappenOnce() throws { + XCTAssertNoThrow( + try self.channel + .writeInbound(HTTPServerRequestPart.head(self.makeRequestHead())) + ) + let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self) + + switch requestPart { + case .some(.head): + () + default: + XCTFail("Unexpected request part: \(String(describing: requestPart))") + } + + // Write three messages into a single body. + var buffer = self.channel.allocator.buffer(capacity: 0) + let serializedMessages: [Data] = try ["foo", "bar", "baz"].map { text in + Echo_EchoRequest.with { $0.text = text } + }.map { request in + try request.serializedData() + } + + for data in serializedMessages { + buffer.writeInteger(UInt8(0)) + buffer.writeInteger(UInt32(data.count)) + buffer.writeBytes(data) + } + + // Create an OnFirstReadHandler that will _also_ send the data when it sees the first read. + // This is try! because it cannot throw. + let onFirstRead = OnFirstReadHandler { + // Let's create a bad message: we'll turn on compression. We use two bytes here to deal with the fact that + // in hitting the error we'll actually consume the first byte (whoops). + var badBuffer = self.channel.allocator.buffer(capacity: 0) + badBuffer.writeInteger(UInt8(1)) + badBuffer.writeInteger(UInt8(1)) + _ = try? self.channel.writeInbound(HTTPServerRequestPart.body(badBuffer)) + } + let errorHandler = ErrorRecordingHandler() + XCTAssertNoThrow(try self.channel.pipeline.addHandlers([onFirstRead, errorHandler]).wait()) + + // Now write. + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body(buffer))) + + // We should have seen the original three messages + for message in serializedMessages { + let requestPart = try self.channel.readInbound(as: _RawGRPCServerRequestPart.self) + switch requestPart { + case var .some(.message(buffer)): + XCTAssertEqual(message, buffer.readData(length: buffer.readableBytes)!) + default: + XCTFail("Unexpected request part: \(String(describing: requestPart))") + } + } + + // We should have recorded only one error. + XCTAssertEqual(errorHandler.errors.count, 1) + } } diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index d1edd0769..54b03e2b0 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -646,7 +646,9 @@ extension HTTP1ToGRPCServerCodecTests { // `swift test --generate-linuxmain` // to regenerate. static let __allTests__HTTP1ToGRPCServerCodecTests = [ + ("testErrorsOnlyHappenOnce", testErrorsOnlyHappenOnce), ("testMultipleMessagesFromSingleBodyPart", testMultipleMessagesFromSingleBodyPart), + ("testReentrantMessageDelivery", testReentrantMessageDelivery), ("testSingleMessageFromMultipleBodyParts", testSingleMessageFromMultipleBodyParts), ] }