Skip to content

Commit

Permalink
Rewrite LengthPrefixedMessageReader and tests (#397)
Browse files Browse the repository at this point in the history
* Rewrite LengthPrefixedMessageReader, add tests

* Switch order of expected and actual in LengthPrefixedMessageReaderTests
  • Loading branch information
glbrntt authored and MrMage committed Mar 15, 2019
1 parent 869c168 commit 143255e
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 102 deletions.
4 changes: 2 additions & 2 deletions Sources/SwiftGRPCNIO/CompressionMechanism.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public enum CompressionMechanism: String {
/// Whether the compression flag in gRPC length-prefixed messages should be set or not.
///
/// See `LengthPrefixedMessageReader` for the message format.
var requiresFlag: Bool {
public var requiresFlag: Bool {
switch self {
case .none:
return false
Expand All @@ -51,7 +51,7 @@ public enum CompressionMechanism: String {
}

/// Whether the given compression is supported.
var supported: Bool {
public var supported: Bool {
switch self {
case .identity, .none:
return true
Expand Down
13 changes: 8 additions & 5 deletions Sources/SwiftGRPCNIO/HTTP1ToRawGRPCClientCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public final class HTTP1ToRawGRPCClientCodec {
}

private var state: State = .expectingHeaders
private let messageReader = LengthPrefixedMessageReader(mode: .client)
private let messageReader = LengthPrefixedMessageReader(mode: .client, compressionMechanism: .none)
private let messageWriter = LengthPrefixedMessageWriter()
private var inboundCompression: CompressionMechanism = .none
}
Expand Down Expand Up @@ -93,14 +93,16 @@ extension HTTP1ToRawGRPCClientCodec: ChannelInboundHandler {
throw GRPCError.client(.HTTPStatusNotOk(head.status))
}

if let encodingType = head.headers["grpc-encoding"].first {
self.inboundCompression = CompressionMechanism(rawValue: encodingType) ?? .unknown
}
let inboundCompression: CompressionMechanism = head.headers["grpc-encoding"]
.first
.map { CompressionMechanism(rawValue: $0) ?? .unknown } ?? .none

guard inboundCompression.supported else {
throw GRPCError.client(.unsupportedCompressionMechanism(inboundCompression.rawValue))
}

self.messageReader.compressionMechanism = inboundCompression

ctx.fireChannelRead(self.wrapInboundOut(.headers(head.headers)))
return .expectingBodyOrTrailers
}
Expand All @@ -114,7 +116,8 @@ extension HTTP1ToRawGRPCClientCodec: ChannelInboundHandler {
throw GRPCError.client(.invalidState("received body while in state \(state)"))
}

for message in try self.messageReader.consume(messageBuffer: &messageBuffer, compression: inboundCompression) {
self.messageReader.append(buffer: &messageBuffer)
while let message = try self.messageReader.nextMessage() {
ctx.fireChannelRead(self.wrapInboundOut(.message(message)))
}

Expand Down
5 changes: 3 additions & 2 deletions Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public final class HTTP1ToRawGRPCServerCodec {
var outboundState = OutboundState.expectingHeaders

var messageWriter = LengthPrefixedMessageWriter()
var messageReader = LengthPrefixedMessageReader(mode: .server)
var messageReader = LengthPrefixedMessageReader(mode: .server, compressionMechanism: .none)
}

extension HTTP1ToRawGRPCServerCodec {
Expand Down Expand Up @@ -148,7 +148,8 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler {
body.write(bytes: decodedData)
}

for message in try messageReader.consume(messageBuffer: &body, compression: .none) {
self.messageReader.append(buffer: &body)
while let message = try self.messageReader.nextMessage() {
ctx.fireChannelRead(self.wrapInboundOut(.message(message)))
}

Expand Down
181 changes: 89 additions & 92 deletions Sources/SwiftGRPCNIO/LengthPrefixedMessageReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
import Foundation
import NIO
import NIOHTTP1

/// This class reads and decodes length-prefixed gRPC messages.
///
Expand All @@ -32,117 +31,115 @@ import NIOHTTP1
public class LengthPrefixedMessageReader {
public typealias Mode = GRPCError.Origin

private let mode: Mode
private var buffer: ByteBuffer!
private var state: State = .expectingCompressedFlag
/// The mechanism that messages will be compressed with.
public var compressionMechanism: CompressionMechanism

private enum State {
case expectingCompressedFlag
case expectingMessageLength
case receivedMessageLength(Int)
case willBuffer(requiredBytes: Int)
case isBuffering(requiredBytes: Int)
}

public init(mode: Mode) {
public init(mode: Mode, compressionMechanism: CompressionMechanism) {
self.mode = mode
self.compressionMechanism = compressionMechanism
}

/// Consumes all readable bytes from given buffer and returns all messages which could be read.
/// The result of trying to parse a message with the bytes we currently have.
///
/// - SeeAlso: `read(messageBuffer:compression:)`
public func consume(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> [ByteBuffer] {
var messages: [ByteBuffer] = []
/// - needMoreData: More data is required to continue reading a message.
/// - continue: Continue reading a message.
/// - message: A message was read.
internal enum ParseResult {
case needMoreData
case `continue`
case message(ByteBuffer)
}

while messageBuffer.readableBytes > 0 {
if let message = try self.read(messageBuffer: &messageBuffer, compression: compression) {
messages.append(message)
}
}
/// The parsing state; what we expect to be reading next.
internal enum ParseState {
case expectingCompressedFlag
case expectingMessageLength
case expectingMessage(UInt32)
}

return messages
private let mode: Mode
private var buffer: ByteBuffer!
private var state: ParseState = .expectingCompressedFlag

/// Appends data to the buffer from which messages will be read.
public func append(buffer: inout ByteBuffer) {
if self.buffer == nil {
self.buffer = buffer.slice()
// mark the bytes as "read"
buffer.moveReaderIndex(forwardBy: buffer.readableBytes)
} else {
self.buffer.write(buffer: &buffer)
}
}

/// Reads bytes from the given buffer until it is exhausted or a message has been read.
///
/// Length prefixed messages may be split across multiple input buffers in any of the
/// following places:
/// 1. after the compression flag,
/// 2. after the message length field,
/// 3. at any point within the message.
///
/// It is possible for the message length field to be split across multiple `ByteBuffer`s,
/// this is unlikely to happen in practice.
/// Reads bytes from the buffer until it is exhausted or a message has been read.
///
/// - Note:
/// This method relies on state; if a message is _not_ returned then the next time this
/// method is called it expects to read the bytes which follow the most recently read bytes.
///
/// - Parameters:
/// - messageBuffer: buffer to read from.
/// - compression: compression mechanism to decode message with.
/// - Returns: A buffer containing a message if one has been read, or `nil` if not enough
/// bytes have been consumed to return a message.
/// - Throws: Throws an error if the compression algorithm is not supported.
public func read(messageBuffer: inout ByteBuffer, compression: CompressionMechanism) throws -> ByteBuffer? {
while true {
switch state {
case .expectingCompressedFlag:
guard let compressionFlag: Int8 = messageBuffer.readInteger() else { return nil }
try handleCompressionFlag(enabled: compressionFlag != 0, mechanism: compression)
self.state = .expectingMessageLength

case .expectingMessageLength:
//! FIXME: Support the message length being split across multiple byte buffers.
guard let messageLength: UInt32 = messageBuffer.readInteger() else { return nil }
self.state = .receivedMessageLength(numericCast(messageLength))

case .receivedMessageLength(let messageLength):
// If this holds true, we can skip buffering and return a slice.
guard messageLength <= messageBuffer.readableBytes else {
self.state = .willBuffer(requiredBytes: messageLength)
continue
}

self.state = .expectingCompressedFlag
// We know messageBuffer.readableBytes >= messageLength, so it's okay to force unwrap here.
return messageBuffer.readSlice(length: messageLength)!

case .willBuffer(let requiredBytes):
messageBuffer.reserveCapacity(requiredBytes)
self.buffer = messageBuffer

let readableBytes = messageBuffer.readableBytes
// Move the reader index to avoid reading the bytes again.
messageBuffer.moveReaderIndex(forwardBy: readableBytes)

self.state = .isBuffering(requiredBytes: requiredBytes - readableBytes)
return nil

case .isBuffering(let requiredBytes):
guard requiredBytes <= messageBuffer.readableBytes else {
self.state = .isBuffering(requiredBytes: requiredBytes - self.buffer.write(buffer: &messageBuffer))
return nil
}

// We know messageBuffer.readableBytes >= requiredBytes, so it's okay to force unwrap here.
var slice = messageBuffer.readSlice(length: requiredBytes)!
self.buffer.write(buffer: &slice)
self.state = .expectingCompressedFlag

defer { self.buffer = nil }
return buffer
public func nextMessage() throws -> ByteBuffer? {
switch try self.processNextState() {
case .needMoreData:
self.nilBufferIfPossible()
return nil

case .continue:
return try nextMessage()

case .message(let message):
self.nilBufferIfPossible()
return message
}
}

/// `nil`s out `buffer` if it exists and has no readable bytes.
///
/// This allows the next call to `append` to avoid writing the contents of the appended buffer.
private func nilBufferIfPossible() {
if self.buffer?.readableBytes == 0 {
self.buffer = nil
}
}

private func processNextState() throws -> ParseResult {
guard self.buffer != nil else { return .needMoreData }

switch self.state {
case .expectingCompressedFlag:
guard let compressionFlag: Int8 = self.buffer.readInteger() else {
return .needMoreData
}
try self.handleCompressionFlag(enabled: compressionFlag != 0)
self.state = .expectingMessageLength

case .expectingMessageLength:
guard let messageLength: UInt32 = self.buffer.readInteger() else {
return .needMoreData
}
self.state = .expectingMessage(messageLength)

case .expectingMessage(let length):
guard let message = self.buffer.readSlice(length: numericCast(length)) else {
return .needMoreData
}
self.state = .expectingCompressedFlag
return .message(message)
}

return .continue
}

private func handleCompressionFlag(enabled flagEnabled: Bool, mechanism: CompressionMechanism) throws {
guard flagEnabled == mechanism.requiresFlag else {
private func handleCompressionFlag(enabled flagEnabled: Bool) throws {
guard flagEnabled else {
return
}

guard self.compressionMechanism.requiresFlag else {
throw GRPCError.common(.unexpectedCompression, origin: mode)
}

guard mechanism.supported else {
throw GRPCError.common(.unsupportedCompressionMechanism(mechanism.rawValue), origin: mode)
guard self.compressionMechanism.supported else {
throw GRPCError.common(.unsupportedCompressionMechanism(compressionMechanism.rawValue), origin: mode)
}
}
}
3 changes: 2 additions & 1 deletion Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ XCTMain([
testCase(NIOClientTimeoutTests.allTests),
testCase(NIOServerWebTests.allTests),
testCase(GRPCChannelHandlerTests.allTests),
testCase(HTTP1ToRawGRPCServerCodecTests.allTests)
testCase(HTTP1ToRawGRPCServerCodecTests.allTests),
testCase(LengthPrefixedMessageReaderTests.allTests),
])
Loading

0 comments on commit 143255e

Please sign in to comment.