Skip to content

Commit

Permalink
Enforce request cardinality for unary-request calls also for the case…
Browse files Browse the repository at this point in the history
… of zero request messages being sent. (#392)

Otherwise, the server will never respond to a call that gets closed without the client sending a response.

In addition, we introduce a method `sendErrorStatus` (happy to discuss naming) on `BaseCallHandler` that sends an error status to the client while ensuring that all call context promises are fulfilled. This method is required (and needs to be overridden) because only the concrete call subclass knows which promises need to be fulfilled.
  • Loading branch information
MrMage authored and rebello95 committed Mar 6, 2019
1 parent 140d34a commit d4a6366
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 16 deletions.
16 changes: 13 additions & 3 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
public func endOfStreamReceived() { }
public func endOfStreamReceived() throws { }

/// Whether this handler can still write messages to the client.
private var serverCanWrite = true
Expand All @@ -30,6 +30,12 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
public init(errorDelegate: ServerErrorDelegate?) {
self.errorDelegate = errorDelegate
}

/// Sends an error status to the client while ensuring that all call context promises are fulfilled.
/// Because only the concrete call subclass knows which promises need to be fulfilled, this method needs to be overridden.
func sendErrorStatus(_ status: GRPCStatus) {
fatalError("needs to be overridden")
}
}

extension BaseCallHandler: ChannelInboundHandler {
Expand All @@ -43,7 +49,7 @@ extension BaseCallHandler: ChannelInboundHandler {

let transformed = errorDelegate?.transform(error) ?? error
let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
sendErrorStatus(status)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -60,7 +66,11 @@ extension BaseCallHandler: ChannelInboundHandler {
}

case .end:
endOfStreamReceived()
do {
try endOfStreamReceived()
} catch {
self.errorCaught(ctx: ctx, error: error)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
}
}

public override func endOfStreamReceived() {
public override func endOfStreamReceived() throws {
eventObserver?.whenSuccess { observer in
observer(.end)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.statusPromise.fail(error: status)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
}
}

public override func endOfStreamReceived() {
public override func endOfStreamReceived() throws {
eventObserver?.whenSuccess { observer in
observer(.end)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.responsePromise.fail(error: status)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
throw GRPCError.server(.requestCardinalityViolation)
throw GRPCError.server(.tooManyRequests)
}

let resultFuture = eventObserver(message)
Expand All @@ -37,4 +37,14 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
.cascade(promise: context.statusPromise)
self.eventObserver = nil
}

public override func endOfStreamReceived() throws {
if self.eventObserver != nil {
throw GRPCError.server(.noRequestsButOneExpected)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.statusPromise.fail(error: status)
}
}
12 changes: 11 additions & 1 deletion Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
throw GRPCError.server(.requestCardinalityViolation)
throw GRPCError.server(.tooManyRequests)
}

let resultFuture = eventObserver(message)
Expand All @@ -38,4 +38,14 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
.cascade(promise: context.responsePromise)
self.eventObserver = nil
}

public override func endOfStreamReceived() throws {
if self.eventObserver != nil {
throw GRPCError.server(.noRequestsButOneExpected)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.responsePromise.fail(error: status)
}
}
10 changes: 8 additions & 2 deletions Sources/SwiftGRPCNIO/GRPCError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ public enum GRPCServerError: Error, Equatable {
/// It was not possible to serialize the response protobuf.
case responseProtoSerializationFailure

/// Zero requests were sent for a unary-request call.
case noRequestsButOneExpected

/// More than one request was sent for a unary-request call.
case requestCardinalityViolation
case tooManyRequests

/// The server received a message when it was not in a writable state.
case serverNotWritable
Expand Down Expand Up @@ -143,7 +146,10 @@ extension GRPCServerError: GRPCStatusTransformable {
case .responseProtoSerializationFailure:
return GRPCStatus(code: .internalError, message: "could not serialize response proto")

case .requestCardinalityViolation:
case .noRequestsButOneExpected:
return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent none")

case .tooManyRequests:
return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more")

case .serverNotWritable:
Expand Down
39 changes: 32 additions & 7 deletions Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
static var allTests: [(String, (NIOServerWebTests) -> () throws -> Void)] {
return [
("testUnary", testUnary),
("testUnaryWithoutRequestMessage", testUnaryWithoutRequestMessage),
//! FIXME: Broken on Linux: https://github.com/grpc/grpc-swift/issues/382
// ("testUnaryLotsOfRequests", testUnaryLotsOfRequests),
("testServerStreaming", testServerStreaming),
Expand All @@ -43,8 +44,8 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
return data
}

private func gRPCWebOKTrailers() -> Data {
var data = "grpc-status: 0\r\ngrpc-message: OK".data(using: .utf8)!
private func gRPCWebTrailers(status: Int = 0, message: String = "OK") -> Data {
var data = "grpc-status: \(status)\r\ngrpc-message: \(message)".data(using: .utf8)!
// Add the gRPC prefix with the compression byte and the 4 length bytes.
for i in 0..<4 {
data.insert(UInt8((data.count >> (i * 8)) & 0xFF), at: 0)
Expand All @@ -53,13 +54,15 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
return data
}

private func sendOverHTTP1(rpcMethod: String, message: String, handler: @escaping (Data?, Error?) -> Void) {
private func sendOverHTTP1(rpcMethod: String, message: String?, handler: @escaping (Data?, Error?) -> Void) {
let serverURL = URL(string: "http://localhost:5050/echo.Echo/\(rpcMethod)")!
var request = URLRequest(url: serverURL)
request.httpMethod = "POST"
request.setValue("application/grpc-web-text", forHTTPHeaderField: "content-type")

request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
if let message = message {
request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
}

let sem = DispatchSemaphore(value: 0)
URLSession.shared.dataTask(with: request) { (data, response, error) in
Expand All @@ -73,7 +76,7 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
extension NIOServerWebTests {
func testUnary() {
let message = "hello, world!"
let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebOKTrailers()
let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebTrailers()
let expectedResponse = expectedData.base64EncodedString()

let completionHandlerExpectation = expectation(description: "completion handler called")
Expand All @@ -83,6 +86,28 @@ extension NIOServerWebTests {
if let data = data {
XCTAssertEqual(String(data: data, encoding: .utf8), expectedResponse)
completionHandlerExpectation.fulfill()
} else {
XCTFail("no data returned")
}
}

waitForExpectations(timeout: defaultTestTimeout)
}

func testUnaryWithoutRequestMessage() {
let expectedData = gRPCWebTrailers(
status: 12, message: "request cardinality violation; method requires exactly one request but client sent none")
let expectedResponse = expectedData.base64EncodedString()

let completionHandlerExpectation = expectation(description: "completion handler called")

sendOverHTTP1(rpcMethod: "Get", message: nil) { data, error in
XCTAssertNil(error)
if let data = data {
XCTAssertEqual(String(data: data, encoding: .utf8), expectedResponse)
completionHandlerExpectation.fulfill()
} else {
XCTFail("no data returned")
}
}

Expand All @@ -104,7 +129,7 @@ extension NIOServerWebTests {

for i in 0..<numberOfRequests {
let message = "foo \(i)"
let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebOKTrailers()
let expectedData = gRPCEncodedEchoRequest("Swift echo get: \(message)") + gRPCWebTrailers()
let expectedResponse = expectedData.base64EncodedString()
sendOverHTTP1(rpcMethod: "Get", message: message) { data, error in
XCTAssertNil(error)
Expand Down Expand Up @@ -132,7 +157,7 @@ extension NIOServerWebTests {
expectedData.append(gRPCEncodedEchoRequest("Swift echo expand (\(index)): \(component)"))
index += 1
}
expectedData.append(gRPCWebOKTrailers())
expectedData.append(gRPCWebTrailers())
let expectedResponse = expectedData.base64EncodedString()
let completionHandlerExpectation = expectation(description: "completion handler called")

Expand Down

0 comments on commit d4a6366

Please sign in to comment.