Skip to content

Commit

Permalink
Move client serialization into the transport handler
Browse files Browse the repository at this point in the history
Motivation:

Moving the client serialization into transport handler allows us to
get rid of an unspecialized generic handler.

Modifications:

- Add an 'AnySerializer' and 'AnyDeserializer', this is unfortunate but
  necessary since 'ClientTransport' is generic over 'Request' and
  'Response' rather than their respective (de/)serializer. Changing this
  would involve changing the generic constraints on all of the client call
  objects.
- Move (de/)serialization into the 'ClientTransport'
- Add a reverse codec for 'fake' transport

Result:

3.5% fewer instructions in the unary_10k_small_requests benchmark.
  • Loading branch information
glbrntt committed Dec 16, 2020
1 parent ea9c7c5 commit 26579c4
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 47 deletions.
20 changes: 19 additions & 1 deletion Sources/GRPC/FakeChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,25 @@ public class FakeChannel: GRPCChannel {
)
}

private func _makeCall<Request, Response>(
private func _makeCall<Request: Message, Response: Message>(
path: String,
type: GRPCCallType,
callOptions: CallOptions,
interceptors: [ClientInterceptor<Request, Response>]
) -> Call<Request, Response> {
let stream: _FakeResponseStream<Request, Response>? = self.dequeueResponseStream(forPath: path)
let eventLoop = stream?.channel.eventLoop ?? EmbeddedEventLoop()
return Call(
path: path,
type: type,
eventLoop: eventLoop,
options: callOptions,
interceptors: interceptors,
transportFactory: .fake(stream, on: eventLoop)
)
}

private func _makeCall<Request: GRPCPayload, Response: GRPCPayload>(
path: String,
type: GRPCCallType,
callOptions: CallOptions,
Expand Down
75 changes: 45 additions & 30 deletions Sources/GRPC/Interceptor/ClientTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ internal final class ClientTransport<Request, Response> {
/// A buffer to store request parts and promises in before the channel has become active.
private var writeBuffer = MarkedCircularBuffer<RequestAndPromise>(initialCapacity: 4)

/// The request serializer.
private let serializer: AnySerializer<Request>

/// The response deserializer.
private let deserializer: AnyDeserializer<Response>

/// A request part and a promise.
private struct RequestAndPromise {
var request: GRPCClientRequestPart<Request>
Expand Down Expand Up @@ -102,12 +108,16 @@ internal final class ClientTransport<Request, Response> {
details: CallDetails,
eventLoop: EventLoop,
interceptors: [ClientInterceptor<Request, Response>],
serializer: AnySerializer<Request>,
deserializer: AnyDeserializer<Response>,
errorDelegate: ClientErrorDelegate?,
onError: @escaping (Error) -> Void,
onResponsePart: @escaping (GRPCClientResponsePart<Response>) -> Void
) {
self.eventLoop = eventLoop
self.callDetails = details
self.serializer = serializer
self.deserializer = deserializer
self._pipeline = ClientInterceptorPipeline(
eventLoop: eventLoop,
details: details,
Expand Down Expand Up @@ -236,10 +246,10 @@ extension ClientTransport {

extension ClientTransport: ChannelInboundHandler {
@usableFromInline
typealias InboundIn = _GRPCClientResponsePart<Response>
typealias InboundIn = _RawGRPCClientResponsePart

@usableFromInline
typealias OutboundOut = _GRPCClientRequestPart<Request>
typealias OutboundOut = _RawGRPCClientRequestPart

@usableFromInline
func handlerAdded(context: ChannelHandlerContext) {
Expand Down Expand Up @@ -311,16 +321,32 @@ extension ClientTransport: ChannelInboundHandler {
self.eventLoop.assertInEventLoop()
let part = self.unwrapInboundIn(data)

let isEnd: Bool
switch part {
case .initialMetadata, .message, .trailingMetadata:
isEnd = false
case .status:
isEnd = true
}
case let .initialMetadata(headers):
if self.state.channelRead(isEnd: false) {
self.forwardToInterceptors(.metadata(headers))
}

if self.state.channelRead(isEnd: isEnd) {
self.forwardToInterceptors(part)
case let .message(context):
do {
let message = try self.deserializer.deserialize(byteBuffer: context.message)
if self.state.channelRead(isEnd: false) {
self.forwardToInterceptors(.message(message))
}
} catch {
self.channelError(error)
}

case let .trailingMetadata(trailers):
// The `Channel` delivers trailers and `GRPCStatus` separately, we want to emit them together
// in the interceptor pipeline.
self.trailers = trailers

case let .status(status):
if self.state.channelRead(isEnd: true) {
self.forwardToInterceptors(.end(status, self.trailers ?? [:]))
self.trailers = nil
}
}

// (We're the end of the channel. No need to forward anything.)
Expand Down Expand Up @@ -769,8 +795,13 @@ extension ClientTransport {
context.channel.write(self.wrapOutboundOut(.head(head)), promise: promise)

case let .message(request, metadata):
let message = _MessageContext<Request>(request, compressed: metadata.compress)
context.channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
do {
let bytes = try self.serializer.serialize(request, allocator: context.channel.allocator)
let message = _MessageContext<ByteBuffer>(bytes, compressed: metadata.compress)
context.channel.write(self.wrapOutboundOut(.message(message)), promise: promise)
} catch {
self.channelError(error)
}

case .end:
context.channel.write(self.wrapOutboundOut(.end), promise: promise)
Expand All @@ -783,24 +814,8 @@ extension ClientTransport {

/// Forward the response part to the interceptor pipeline.
/// - Parameter part: The response part to forward.
private func forwardToInterceptors(_ part: _GRPCClientResponsePart<Response>) {
switch part {
case let .initialMetadata(metadata):
self._pipeline?.receive(.metadata(metadata))

case let .message(context):
self._pipeline?.receive(.message(context.message))

case let .trailingMetadata(trailers):
// The `Channel` delivers trailers and `GRPCStatus`, we want to emit them together in the
// interceptor pipeline.
self.trailers = trailers

case let .status(status):
let trailers = self.trailers ?? [:]
self.trailers = nil
self._pipeline?.receive(.end(status, trailers))
}
private func forwardToInterceptors(_ part: GRPCClientResponsePart<Response>) {
self._pipeline?.receive(part)
}

/// Forward the error to the interceptor pipeline.
Expand Down
91 changes: 76 additions & 15 deletions Sources/GRPC/Interceptor/ClientTransportFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ internal struct ClientTransportFactory<Request, Response> {
multiplexer: multiplexer,
scheme: scheme,
authority: authority,
serializer: GRPCPayloadSerializer(),
deserializer: GRPCPayloadDeserializer(),
serializer: AnySerializer(wrapping: GRPCPayloadSerializer()),
deserializer: AnyDeserializer(wrapping: GRPCPayloadDeserializer()),
errorDelegate: errorDelegate
)
return .init(http2)
Expand All @@ -87,11 +87,37 @@ internal struct ClientTransportFactory<Request, Response> {
/// Make a factory for 'fake' transport.
/// - Parameter fakeResponse: The fake response stream.
/// - Returns: A factory for making and configuring fake transport.
internal static func fake(
internal static func fake<Request: SwiftProtobuf.Message, Response: SwiftProtobuf.Message>(
_ fakeResponse: _FakeResponseStream<Request, Response>?,
on eventLoop: EventLoop
) -> ClientTransportFactory<Request, Response> {
return .init(FakeClientTransportFactory(fakeResponse, on: eventLoop))
let factory = FakeClientTransportFactory(
fakeResponse,
on: eventLoop,
requestSerializer: ProtobufSerializer(),
requestDeserializer: ProtobufDeserializer(),
responseSerializer: ProtobufSerializer(),
responseDeserializer: ProtobufDeserializer()
)
return .init(factory)
}

/// Make a factory for 'fake' transport.
/// - Parameter fakeResponse: The fake response stream.
/// - Returns: A factory for making and configuring fake transport.
internal static func fake<Request: GRPCPayload, Response: GRPCPayload>(
_ fakeResponse: _FakeResponseStream<Request, Response>?,
on eventLoop: EventLoop
) -> ClientTransportFactory<Request, Response> {
let factory = FakeClientTransportFactory(
fakeResponse,
on: eventLoop,
requestSerializer: GRPCPayloadSerializer(),
requestDeserializer: GRPCPayloadDeserializer(),
responseSerializer: GRPCPayloadSerializer(),
responseDeserializer: GRPCPayloadDeserializer()
)
return .init(factory)
}

/// Makes a configured `ClientTransport`.
Expand All @@ -103,7 +129,7 @@ internal struct ClientTransportFactory<Request, Response> {
/// - onError: A callback invoked when an error is received.
/// - onResponsePart: A closure called for each response part received.
/// - Returns: A configured transport.
internal func makeConfiguredTransport<Request, Response>(
internal func makeConfiguredTransport(
to path: String,
for type: GRPCCallType,
withOptions options: CallOptions,
Expand Down Expand Up @@ -151,8 +177,11 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
/// An error delegate.
private var errorDelegate: ClientErrorDelegate?

/// A codec for serializing request messages and deserializing response parts.
private var codec: ChannelHandler
/// The request serializer.
private let serializer: AnySerializer<Request>

/// The response deserializer.
private let deserializer: AnyDeserializer<Response>

fileprivate init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
multiplexer: EventLoopFuture<HTTP2StreamMultiplexer>,
Expand All @@ -165,11 +194,12 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
self.multiplexer = multiplexer
self.scheme = scheme
self.authority = authority
self.codec = GRPCClientCodecHandler(serializer: serializer, deserializer: deserializer)
self.serializer = AnySerializer(wrapping: serializer)
self.deserializer = AnyDeserializer(wrapping: deserializer)
self.errorDelegate = errorDelegate
}

fileprivate func makeTransport<Request, Response>(
fileprivate func makeTransport(
to path: String,
for type: GRPCCallType,
withOptions options: CallOptions,
Expand All @@ -181,6 +211,8 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
details: self.makeCallDetails(type: type, path: path, options: options),
eventLoop: self.multiplexer.eventLoop,
interceptors: interceptors,
serializer: self.serializer,
deserializer: self.deserializer,
errorDelegate: self.errorDelegate,
onError: onError,
onResponsePart: onResponsePart
Expand All @@ -198,7 +230,6 @@ private struct HTTP2ClientTransportFactory<Request, Response> {
callType: transport.callDetails.type,
logger: transport.logger
),
self.codec,
transport,
])
}
Expand Down Expand Up @@ -233,15 +264,43 @@ private struct FakeClientTransportFactory<Request, Response> {
/// stream be `nil`.
private var eventLoop: EventLoop

fileprivate init(
/// The request serializer.
private let requestSerializer: AnySerializer<Request>

/// The response deserializer.
private let responseDeserializer: AnyDeserializer<Response>

/// A codec for deserializing requests and serializing responses.
private let codec: ChannelHandler

fileprivate init<
RequestSerializer: MessageSerializer,
RequestDeserializer: MessageDeserializer,
ResponseSerializer: MessageSerializer,
ResponseDeserializer: MessageDeserializer
>(
_ fakeResponseStream: _FakeResponseStream<Request, Response>?,
on eventLoop: EventLoop
) {
on eventLoop: EventLoop,
requestSerializer: RequestSerializer,
requestDeserializer: RequestDeserializer,
responseSerializer: ResponseSerializer,
responseDeserializer: ResponseDeserializer
) where RequestSerializer.Input == Request,
RequestDeserializer.Output == Request,
ResponseSerializer.Input == Response,
ResponseDeserializer.Output == Response
{
self.fakeResponseStream = fakeResponseStream
self.eventLoop = eventLoop
self.requestSerializer = AnySerializer(wrapping: requestSerializer)
self.responseDeserializer = AnyDeserializer(wrapping: responseDeserializer)
self.codec = GRPCClientReverseCodecHandler(
serializer: responseSerializer,
deserializer: requestDeserializer
)
}

fileprivate func makeTransport<Request, Response>(
fileprivate func makeTransport(
to path: String,
for type: GRPCCallType,
withOptions options: CallOptions,
Expand All @@ -259,6 +318,8 @@ private struct FakeClientTransportFactory<Request, Response> {
),
eventLoop: self.eventLoop,
interceptors: interceptors,
serializer: self.requestSerializer,
deserializer: self.responseDeserializer,
errorDelegate: nil,
onError: onError,
onResponsePart: onResponsePart
Expand All @@ -268,7 +329,7 @@ private struct FakeClientTransportFactory<Request, Response> {
fileprivate func configure<Request, Response>(_ transport: ClientTransport<Request, Response>) {
transport.configure { handler in
if let fakeResponse = self.fakeResponseStream {
return fakeResponse.channel.pipeline.addHandler(handler).always { result in
return fakeResponse.channel.pipeline.addHandlers(self.codec, handler).always { result in
switch result {
case .success:
fakeResponse.activate()
Expand Down
28 changes: 28 additions & 0 deletions Sources/GRPC/Serialization.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,31 @@ public struct GRPCPayloadDeserializer<Message: GRPCPayload>: MessageDeserializer
return try Message(serializedByteBuffer: &buffer)
}
}

// MARK: - Any Serializer/Deserializer

internal struct AnySerializer<Input>: MessageSerializer {
private let _serialize: (Input, ByteBufferAllocator) throws -> ByteBuffer

init<Serializer: MessageSerializer>(wrapping other: Serializer) where Serializer.Input == Input {
self._serialize = other.serialize(_:allocator:)
}

internal func serialize(_ input: Input, allocator: ByteBufferAllocator) throws -> ByteBuffer {
return try self._serialize(input, allocator)
}
}

internal struct AnyDeserializer<Output>: MessageDeserializer {
private let _deserialize: (ByteBuffer) throws -> Output

init<Deserializer: MessageDeserializer>(
wrapping other: Deserializer
) where Deserializer.Output == Output {
self._deserialize = other.deserialize(byteBuffer:)
}

internal func deserialize(byteBuffer: ByteBuffer) throws -> Output {
return try self._deserialize(byteBuffer)
}
}
Loading

0 comments on commit 26579c4

Please sign in to comment.