diff --git a/Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift b/Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift index 97826a73e..635e1d716 100644 --- a/Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift +++ b/Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift @@ -132,6 +132,7 @@ public class BidirectionalStreamingCallHandler< eventLoop: self.eventLoop, headers: headers, logger: self.logger, + userInfoRef: self.userInfoRef, sendResponse: self.sendResponse(_:metadata:promise:) ) let observer = factory(context) diff --git a/Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift b/Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift index 95d2a0168..6d2421650 100644 --- a/Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift +++ b/Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift @@ -132,7 +132,8 @@ public final class ClientStreamingCallHandler< let context = UnaryResponseCallContext( eventLoop: self.eventLoop, headers: headers, - logger: self.logger + logger: self.logger, + userInfoRef: self.userInfoRef ) let observer = factory(context) diff --git a/Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift b/Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift index f3e524f55..14045cfc2 100644 --- a/Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift @@ -137,6 +137,7 @@ public final class ServerStreamingCallHandler< eventLoop: self.eventLoop, headers: headers, logger: self.logger, + userInfoRef: self.userInfoRef, sendResponse: self.sendResponse(_:metadata:promise:) ) let observer = factory(context) diff --git a/Sources/GRPC/CallHandlers/UnaryCallHandler.swift b/Sources/GRPC/CallHandlers/UnaryCallHandler.swift index 72ee5a6c7..b022f678d 100644 --- a/Sources/GRPC/CallHandlers/UnaryCallHandler.swift +++ b/Sources/GRPC/CallHandlers/UnaryCallHandler.swift @@ -163,7 +163,8 @@ public final class UnaryCallHandler< let context = UnaryResponseCallContext( eventLoop: self.eventLoop, headers: headers, - logger: self.logger + logger: self.logger, + userInfoRef: self.userInfoRef ) let observer = factory(context) diff --git a/Sources/GRPC/CallHandlers/_BaseCallHandler.swift b/Sources/GRPC/CallHandlers/_BaseCallHandler.swift index e5da7ffc3..9a25464cd 100644 --- a/Sources/GRPC/CallHandlers/_BaseCallHandler.swift +++ b/Sources/GRPC/CallHandlers/_BaseCallHandler.swift @@ -56,20 +56,27 @@ public class _BaseCallHandler: GRPCCallHandler, ChannelInboun return self.callHandlerContext.logger } + /// A reference to `UserInfo`. + internal var userInfoRef: Ref + internal init( callHandlerContext: CallHandlerContext, codec: ChannelHandler, callType: GRPCCallType, interceptors: [ServerInterceptor] ) { + let userInfoRef = Ref(UserInfo()) + self.callHandlerContext = callHandlerContext self._codec = codec self.callType = callType + self.userInfoRef = userInfoRef self.pipeline = ServerInterceptorPipeline( logger: callHandlerContext.logger, eventLoop: callHandlerContext.eventLoop, path: callHandlerContext.path, callType: callType, + userInfoRef: userInfoRef, interceptors: interceptors, onRequestPart: self.receiveRequestPartFromInterceptors(_:), onResponsePart: self.sendResponsePartFromInterceptors(_:promise:) diff --git a/Sources/GRPC/Interceptor/ServerInterceptorContext.swift b/Sources/GRPC/Interceptor/ServerInterceptorContext.swift index 45457363b..4e1c2e78e 100644 --- a/Sources/GRPC/Interceptor/ServerInterceptorContext.swift +++ b/Sources/GRPC/Interceptor/ServerInterceptorContext.swift @@ -56,6 +56,21 @@ public struct ServerInterceptorContext { return self.pipeline.path } + /// A 'UserInfo' dictionary. + /// + /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a + /// reference wrapped `UserInfo`. The contexts passed to the service provider share the same + /// reference. As such this may be used as a mechanism to pass information between interceptors + /// and service providers. + public var userInfo: UserInfo { + get { + return self.pipeline.userInfoRef.value + } + nonmutating set { + self.pipeline.userInfoRef.value = newValue + } + } + /// Construct a `ServerInterceptorContext` for the interceptor at the given index within the /// interceptor pipeline. internal init( diff --git a/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift b/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift index 3a7b08f7c..04144e1ce 100644 --- a/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift +++ b/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift @@ -29,6 +29,9 @@ internal final class ServerInterceptorPipeline { /// A logger. internal let logger: Logger + /// A reference to a 'UserInfo'. + internal let userInfoRef: Ref + /// The contexts associated with the interceptors stored in this pipeline. Contexts will be /// removed once the RPC has completed. Contexts are ordered from inbound to outbound, that is, /// the head is first and the tail is last. @@ -80,6 +83,7 @@ internal final class ServerInterceptorPipeline { eventLoop: EventLoop, path: String, callType: GRPCCallType, + userInfoRef: Ref, interceptors: [ServerInterceptor], onRequestPart: @escaping (ServerRequestPart) -> Void, onResponsePart: @escaping (ServerResponsePart, EventLoopPromise?) -> Void @@ -88,6 +92,7 @@ internal final class ServerInterceptorPipeline { self.eventLoop = eventLoop self.path = path self.type = callType + self.userInfoRef = userInfoRef // We need space for the head and tail as well as any user provided interceptors. var contexts: [ServerInterceptorContext] = [] diff --git a/Sources/GRPC/Ref.swift b/Sources/GRPC/Ref.swift new file mode 100644 index 000000000..ca1fe706e --- /dev/null +++ b/Sources/GRPC/Ref.swift @@ -0,0 +1,22 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +internal final class Ref { + internal var value: Value + internal init(_ value: Value) { + self.value = value + } +} diff --git a/Sources/GRPC/ServerCallContexts/ServerCallContext.swift b/Sources/GRPC/ServerCallContexts/ServerCallContext.swift index af9c19d32..e69663c02 100644 --- a/Sources/GRPC/ServerCallContexts/ServerCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/ServerCallContext.swift @@ -28,6 +28,9 @@ public protocol ServerCallContext: AnyObject { /// Request headers for this request. var headers: HPACKHeaders { get } + /// A 'UserInfo' dictionary. + var userInfo: UserInfo { get set } + /// The logger used for this call. var logger: Logger { get } @@ -44,21 +47,53 @@ open class ServerCallContextBase: ServerCallContext { public let logger: Logger public var compressionEnabled: Bool = true + /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a + /// reference wrapped `UserInfo`. The contexts passed to interceptors provide the same + /// reference. As such this may be used as a mechanism to pass information between interceptors + /// and service providers. + public var userInfo: UserInfo { + get { + return self.userInfoRef.value + } + set { + self.userInfoRef.value = newValue + } + } + + /// A reference to an underlying `UserInfo`. We share this with the interceptors. + private let userInfoRef: Ref + /// Metadata to return at the end of the RPC. If this is required it should be updated before /// the `responsePromise` or `statusPromise` is fulfilled. public var trailers = HPACKHeaders() - public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) { + public convenience init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfo: UserInfo = UserInfo() + ) { + self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo)) + } + + internal init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfoRef: Ref + ) { self.eventLoop = eventLoop self.headers = headers + self.userInfoRef = userInfoRef self.logger = logger } - @available(*, deprecated, renamed: "init(eventLoop:headers:logger:)") + @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)") public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) { self.eventLoop = eventLoop self.headers = HPACKHeaders(httpHeaders: request.headers, normalizeHTTPHeaders: false) self.logger = logger + self.userInfoRef = .init(UserInfo()) } /// Processes an error, transforming it into a 'GRPCStatus' and any trailers to send to the peer. diff --git a/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift b/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift index dcaa2fe3c..b4dfa4cc3 100644 --- a/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift @@ -31,12 +31,26 @@ open class StreamingResponseCallContext: ServerCallContextBase public let statusPromise: EventLoopPromise - override public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) { + public convenience init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfo: UserInfo = UserInfo() + ) { + self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo)) + } + + override internal init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfoRef: Ref + ) { self.statusPromise = eventLoop.makePromise() - super.init(eventLoop: eventLoop, headers: headers, logger: logger) + super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef) } - @available(*, deprecated, renamed: "init(eventLoop:path:headers:logger:)") + @available(*, deprecated, renamed: "init(eventLoop:path:headers:logger:userInfo:)") override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) { self.statusPromise = eventLoop.makePromise() super.init(eventLoop: eventLoop, request: request, logger: logger) @@ -113,10 +127,11 @@ internal final class _StreamingResponseCallContext: eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger, + userInfoRef: Ref, sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise?) -> Void ) { self._sendResponse = sendResponse - super.init(eventLoop: eventLoop, headers: headers, logger: logger) + super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef) } override func sendResponse( @@ -165,7 +180,12 @@ open class StreamingResponseCallContextImpl: StreamingResponseC logger: Logger ) { self.channel = channel - super.init(eventLoop: channel.eventLoop, headers: headers, logger: logger) + super.init( + eventLoop: channel.eventLoop, + headers: headers, + logger: logger, + userInfoRef: Ref(UserInfo()) + ) self.statusPromise.futureResult.whenComplete { result in switch result { diff --git a/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift b/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift index 76e5268d0..3121d401f 100644 --- a/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift @@ -35,12 +35,26 @@ open class UnaryResponseCallContext: ServerCallContextBase, Sta public let responsePromise: EventLoopPromise public var responseStatus: GRPCStatus = .ok - override public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) { + public convenience init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfo: UserInfo = UserInfo() + ) { + self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo)) + } + + override internal init( + eventLoop: EventLoop, + headers: HPACKHeaders, + logger: Logger, + userInfoRef: Ref + ) { self.responsePromise = eventLoop.makePromise() - super.init(eventLoop: eventLoop, headers: headers, logger: logger) + super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef) } - @available(*, deprecated, renamed: "init(eventLoop:headers:logger:)") + @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)") override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) { self.responsePromise = eventLoop.makePromise() super.init(eventLoop: eventLoop, request: request, logger: logger) @@ -90,7 +104,12 @@ open class UnaryResponseCallContextImpl: UnaryResponseCallConte logger: Logger ) { self.channel = channel - super.init(eventLoop: channel.eventLoop, headers: headers, logger: logger) + super.init( + eventLoop: channel.eventLoop, + headers: headers, + logger: logger, + userInfoRef: .init(UserInfo()) + ) self.responsePromise.futureResult.whenComplete { [self, weak errorDelegate] result in switch result { diff --git a/Sources/GRPC/UserInfo.swift b/Sources/GRPC/UserInfo.swift new file mode 100644 index 000000000..f292bc6c3 --- /dev/null +++ b/Sources/GRPC/UserInfo.swift @@ -0,0 +1,106 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// `UserInfo` is a dictionary for heterogeneously typed values with type safe access to the stored +/// values. +/// +/// Values are keyed by a type conforming to the `UserInfo.Key` protocol. The protocol requires an +/// `associatedtype`: the type of the value the key is paired with. A key can be created using a +/// caseless `enum`, for example: +/// +/// ``` +/// enum IDKey: UserInfo.Key { +/// typealias Value = Int +/// } +/// ``` +/// +/// Values can be set and retrieved from `UserInfo` by subscripting with the key: +/// +/// ``` +/// userInfo[IDKey.self] = 42 +/// let id = userInfo[IDKey.self] // id = 42 +/// +/// userInfo[IDKey.self] = nil +/// ``` +/// +/// More convenient access can be provided with helper extensions on `UserInfo`: +/// +/// ``` +/// extension UserInfo { +/// var id: IDKey.Value? { +/// get { self[IDKey.self] } +/// set { self[IDKey.self] = newValue } +/// } +/// } +/// ``` +public struct UserInfo: CustomStringConvertible { + private var storage: [AnyUserInfoKey: Any] + + /// A protocol for a key. + public typealias Key = UserInfoKey + + /// Create an empty 'UserInfo'. + public init() { + self.storage = [:] + } + + /// Allows values to be set and retrieved in a type safe way. + public subscript(key: Key.Type) -> Key.Value? { + get { + if let anyValue = self.storage[AnyUserInfoKey(key)] { + // The types must line up here. + return (anyValue as! Key.Value) + } else { + return nil + } + } + set { + self.storage[AnyUserInfoKey(key)] = newValue + } + } + + public var description: String { + return "[" + self.storage.map { key, value in + "\(key): \(value)" + }.joined(separator: ", ") + "]" + } + + /// A `UserInfoKey` wrapper. + private struct AnyUserInfoKey: Hashable, CustomStringConvertible { + private let keyType: Any.Type + + var description: String { + return String(describing: self.keyType.self) + } + + init(_ keyType: Key.Type) { + self.keyType = keyType + } + + static func == (lhs: AnyUserInfoKey, rhs: AnyUserInfoKey) -> Bool { + return ObjectIdentifier(lhs.keyType) == ObjectIdentifier(rhs.keyType) + } + + func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self.keyType)) + } + } +} + +public protocol UserInfoKey { + /// The type of `Value` identified by this key. + associatedtype Value +} diff --git a/Tests/GRPCTests/InterceptorsTests.swift b/Tests/GRPCTests/InterceptorsTests.swift index 305f5c4e4..0ff3c90fe 100644 --- a/Tests/GRPCTests/InterceptorsTests.swift +++ b/Tests/GRPCTests/InterceptorsTests.swift @@ -137,6 +137,9 @@ class HelloWorldProvider: Helloworld_GreeterProvider { request: Helloworld_HelloRequest, context: StatusOnlyCallContext ) -> EventLoopFuture { + // Since we're auth'd, the 'userInfo' should have some magic set. + assertThat(context.userInfo.magic, .is("Magic")) + let response = Helloworld_HelloReply.with { $0.message = "Hello, \(request.name), you're authorized!" } @@ -166,7 +169,8 @@ class NotReallyAuthServerInterceptor: ) { switch part { case let .metadata(headers): - if headers.first(name: "authorization") == "Magic" { + if let auth = headers.first(name: "authorization"), auth == "Magic" { + context.userInfo.magic = auth context.receive(part) } else { // Not auth'd. Fail the RPC. @@ -340,3 +344,18 @@ private class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol { return self.interceptors } } + +private enum MagicKey: UserInfo.Key { + typealias Value = String +} + +extension UserInfo { + fileprivate var magic: MagicKey.Value? { + get { + return self[MagicKey.self] + } + set { + self[MagicKey.self] = newValue + } + } +} diff --git a/Tests/GRPCTests/ServerInterceptorPipelineTests.swift b/Tests/GRPCTests/ServerInterceptorPipelineTests.swift index a47222ea4..2b5ff6b1a 100644 --- a/Tests/GRPCTests/ServerInterceptorPipelineTests.swift +++ b/Tests/GRPCTests/ServerInterceptorPipelineTests.swift @@ -40,6 +40,7 @@ class ServerInterceptorPipelineTests: GRPCTestCase { eventLoop: self.embeddedEventLoop, path: path, callType: callType, + userInfoRef: Ref(UserInfo()), interceptors: interceptors, onRequestPart: onRequestPart, onResponsePart: onResponsePart diff --git a/Tests/GRPCTests/UserInfoTests.swift b/Tests/GRPCTests/UserInfoTests.swift new file mode 100644 index 000000000..8e1bc14dd --- /dev/null +++ b/Tests/GRPCTests/UserInfoTests.swift @@ -0,0 +1,87 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPC + +class UserInfoTests: GRPCTestCase { + func testWithSubscript() { + var userInfo = UserInfo() + + userInfo[FooKey.self] = "foo" + assertThat(userInfo[FooKey.self], .is("foo")) + + userInfo[BarKey.self] = 42 + assertThat(userInfo[BarKey.self], .is(42)) + + userInfo[FooKey.self] = nil + assertThat(userInfo[FooKey.self], .is(.nil())) + + userInfo[BarKey.self] = nil + assertThat(userInfo[BarKey.self], .is(.nil())) + } + + func testWithExtensions() { + var userInfo = UserInfo() + + userInfo.foo = "foo" + assertThat(userInfo.foo, .is("foo")) + + userInfo.bar = 42 + assertThat(userInfo.bar, .is(42)) + + userInfo.foo = nil + assertThat(userInfo.foo, .is(.nil())) + + userInfo.bar = nil + assertThat(userInfo.bar, .is(.nil())) + } + + func testDescription() { + var userInfo = UserInfo() + assertThat(String(describing: userInfo), .is("[]")) + + // (We can't test with multiple values since ordering isn't stable.) + userInfo.foo = "foo" + assertThat(String(describing: userInfo), .is("[FooKey: foo]")) + } +} + +private enum FooKey: UserInfoKey { + typealias Value = String +} + +private enum BarKey: UserInfoKey { + typealias Value = Int +} + +extension UserInfo { + fileprivate var foo: FooKey.Value? { + get { + return self[FooKey.self] + } + set { + self[FooKey.self] = newValue + } + } + + fileprivate var bar: BarKey.Value? { + get { + return self[BarKey.self] + } + set { + self[BarKey.self] = newValue + } + } +} diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index 0aee681dd..ea8e839cf 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -1032,6 +1032,17 @@ extension TimeLimitTests { ] } +extension UserInfoTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__UserInfoTests = [ + ("testDescription", testDescription), + ("testWithExtensions", testWithExtensions), + ("testWithSubscript", testWithSubscript), + ] +} + extension ZeroLengthWriteTests { // DO NOT MODIFY: This is autogenerated, use: // `swift test --generate-linuxmain` @@ -1131,6 +1142,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(StopwatchTests.__allTests__StopwatchTests), testCase(StreamingRequestClientCallTests.__allTests__StreamingRequestClientCallTests), testCase(TimeLimitTests.__allTests__TimeLimitTests), + testCase(UserInfoTests.__allTests__UserInfoTests), testCase(ZeroLengthWriteTests.__allTests__ZeroLengthWriteTests), testCase(ZlibTests.__allTests__ZlibTests), ]