From 0eea842ca8a6ae696ca40267642577766b03c0fc Mon Sep 17 00:00:00 2001 From: danthorpe Date: Sat, 3 Aug 2024 15:00:36 +0100 Subject: [PATCH] feat: traced() fixes: #38 --- Sources/Helpers/Data+Crypto.swift | 14 +++ Sources/Helpers/UniqueIdentifier.swift | 72 +++++++++++ Sources/Networking/Components/Traced.swift | 113 ++++++++++++++++++ Sources/TestSupport/Mocked.swift | 7 ++ Sources/TestSupport/NetworkingTestCase.swift | 1 + .../Components/TracedTests.swift | 52 ++++++++ 6 files changed, 259 insertions(+) create mode 100644 Sources/Helpers/Data+Crypto.swift create mode 100644 Sources/Helpers/UniqueIdentifier.swift create mode 100644 Sources/Networking/Components/Traced.swift create mode 100644 Tests/NetworkingTests/Components/TracedTests.swift diff --git a/Sources/Helpers/Data+Crypto.swift b/Sources/Helpers/Data+Crypto.swift new file mode 100644 index 00000000..ba017c01 --- /dev/null +++ b/Sources/Helpers/Data+Crypto.swift @@ -0,0 +1,14 @@ +import CryptoKit +import Foundation + +extension Data { + + static func secureRandomData(length: UInt) -> Data? { + let count = Int(length) + var bytes = [Int8](repeating: 0, count: count) + guard errSecSuccess == SecRandomCopyBytes(kSecRandomDefault, count, &bytes) else { + return nil + } + return Data(bytes: bytes, count: count) + } +} diff --git a/Sources/Helpers/UniqueIdentifier.swift b/Sources/Helpers/UniqueIdentifier.swift new file mode 100644 index 00000000..8b7f53ef --- /dev/null +++ b/Sources/Helpers/UniqueIdentifier.swift @@ -0,0 +1,72 @@ +import ConcurrencyExtras +import Foundation + +package enum UniqueIdentifier: Hashable { + package enum Format: Hashable { + case base64, hex + } + case secureBytes(length: UInt = 10, format: Format) +} + +extension UniqueIdentifier { + + func generate() -> String { + switch self { + case let .secureBytes(length, _): + var data = Data() + repeat { + data = .secureRandomData(length: length) ?? Data() + } while data.isEmpty + return format(data: data) + } + } + + func format(data: Data) -> String { + switch self { + case .secureBytes(_, .base64): + return data.base64EncodedString(options: []) + case .secureBytes(_, .hex): + return data.map { String(format: "%02hhx", $0) }.joined() + } + } +} + +// MARK: - Generator + +extension UniqueIdentifier { + package struct Generator: Sendable { + private let generate: @Sendable () -> String + + package init(_ id: UniqueIdentifier) { + self.init { id.generate() } + } + + package init(generate: @escaping @Sendable () -> String) { + self.generate = generate + } + + @discardableResult + package func callAsFunction() -> String { + generate() + } + } +} + +extension UniqueIdentifier.Generator { + package static func constant(_ id: UniqueIdentifier) -> Self { + let generation = id.generate() + return Self { generation } + } + + package static func incrementing(_ id: UniqueIdentifier) -> Self { + let sequence = LockIsolated(0) + return Self { + let number = sequence.withValue { + $0 += 1 + return $0 + } + let data = withUnsafeBytes(of: number.bigEndian) { Data($0) } + return id.format(data: data) + } + } +} diff --git a/Sources/Networking/Components/Traced.swift b/Sources/Networking/Components/Traced.swift new file mode 100644 index 00000000..f17d9cf8 --- /dev/null +++ b/Sources/Networking/Components/Traced.swift @@ -0,0 +1,113 @@ +import Dependencies +import DependenciesMacros +import Foundation +import HTTPTypes +import Helpers + +extension NetworkingComponent { + + /// Generates a HTTP Trace Parent header for each request. + /// + /// - See-Also: [Trace-Context](https://www.w3.org/TR/trace-context/) + public func traced() -> some NetworkingComponent { + modified(Traced()) + } +} + +private struct Traced: NetworkingModifier { + @Dependency(\.traceParentGenerator) var generate + func resolve(upstream: some NetworkingComponent, request: HTTPRequestData) -> HTTPRequestData { + guard nil == request.traceParent else { + return request + } + var copy = request + copy.traceParent = generate() + return copy + } +} + +extension HTTPField.Name { + public static let traceparent = HTTPField.Name("traceparent")! +} + +extension HTTPRequestData { + package fileprivate(set) var traceParent: TraceParent? { + get { self[option: TraceParent.self] } + set { + self[option: TraceParent.self] = newValue + self.headerFields[.traceparent] = newValue?.description + } + } + + public var traceId: String? { + traceParent?.traceId + } + + public var parentId: String? { + traceParent?.parentId + } +} + +public struct TraceParent: Sendable, HTTPRequestDataOption { + public static var defaultOption: Self? + + // Current version of the spec only supports 01 flag + // Future versions of the spec will require support for bit-field mask + public let traceId: String + public let parentId: String + + public var description: String { + "00-\(traceId)-\(parentId)-01" + } + + public init(traceId: String, parentId: String) { + self.traceId = traceId + self.parentId = parentId + } +} + +// MARK: - Generator + +@DependencyClient +public struct TraceParentGenerator: Sendable { + public var generate: @Sendable () -> TraceParent = { + TraceParent(traceId: "dummy-trace-id", parentId: "dummy-parent-id") + } + + package func callAsFunction() -> TraceParent { + generate() + } +} + +extension TraceParentGenerator: DependencyKey { + public static let liveValue = { + let traceId = UniqueIdentifier.Generator(.secureBytes(length: 16, format: .hex)) + let parentId = UniqueIdentifier.Generator(.secureBytes(length: 8, format: .hex)) + return TraceParentGenerator { + TraceParent( + traceId: traceId(), + parentId: parentId() + ) + } + }() +} + +extension DependencyValues { + public var traceParentGenerator: TraceParentGenerator { + get { self[TraceParentGenerator.self] } + set { self[TraceParentGenerator.self] = newValue } + } +} + +extension TraceParentGenerator { + public static let incrementing = { + let traceId = UniqueIdentifier.Generator.incrementing(.secureBytes(length: 16, format: .hex)) + let parentId = UniqueIdentifier.Generator.incrementing(.secureBytes(length: 8, format: .hex)) + return TraceParentGenerator { + TraceParent( + traceId: traceId(), + parentId: parentId() + ) + } + }() +} diff --git a/Sources/TestSupport/Mocked.swift b/Sources/TestSupport/Mocked.swift index 81c20f84..1e9fe5a1 100644 --- a/Sources/TestSupport/Mocked.swift +++ b/Sources/TestSupport/Mocked.swift @@ -2,6 +2,13 @@ import Networking extension NetworkingComponent { + /// Mock all requests with a stub + public func mocked( + all stub: StubbedResponseStream + ) -> some NetworkingComponent { + mocked(stub) { _ in true } + } + /// Mock a given request with a stub public func mocked( _ request: HTTPRequestData, diff --git a/Sources/TestSupport/NetworkingTestCase.swift b/Sources/TestSupport/NetworkingTestCase.swift index 6f89f899..bda2471d 100644 --- a/Sources/TestSupport/NetworkingTestCase.swift +++ b/Sources/TestSupport/NetworkingTestCase.swift @@ -23,6 +23,7 @@ open class NetworkingTestCase: XCTestCase { ) { withDependencies { $0.shortID = shortIdGenerator ?? .incrementing + $0.traceParentGenerator = .incrementing $0.continuousClock = continuousClock ?? TestClock() updateValuesForOperation(&$0) } operation: { diff --git a/Tests/NetworkingTests/Components/TracedTests.swift b/Tests/NetworkingTests/Components/TracedTests.swift new file mode 100644 index 00000000..4b5eb697 --- /dev/null +++ b/Tests/NetworkingTests/Components/TracedTests.swift @@ -0,0 +1,52 @@ +import Foundation +import TestSupport +import XCTest + +@testable import Networking + +final class TracedTests: NetworkingTestCase { + override func invokeTest() { + withTestDependencies { + super.invokeTest() + } + } + + func test__request_includes_trace() async throws { + let reporter = TestReporter() + + let network = TerminalNetworkingComponent() + .mocked(all: .ok()) + .reported(by: reporter) + .traced() + + try await withThrowingTaskGroup(of: HTTPResponseData.self) { group in + for _ in 0 ..< 10 { + group.addTask { + try await network.data(HTTPRequestData()) + } + } + + var responses: [HTTPResponseData] = [] + for try await response in group { + responses.append(response) + } + } + + let sentRequests = await reporter.requests + + XCTAssertEqual( + sentRequests.map(\.headerFields[.traceparent]), + [ + "00-0000000000000001-0000000000000001-01", + "00-0000000000000002-0000000000000002-01", + "00-0000000000000003-0000000000000003-01", + "00-0000000000000004-0000000000000004-01", + "00-0000000000000005-0000000000000005-01", + "00-0000000000000006-0000000000000006-01", + "00-0000000000000007-0000000000000007-01", + "00-0000000000000008-0000000000000008-01", + "00-0000000000000009-0000000000000009-01", + "00-000000000000000a-000000000000000a-01", + ]) + } +}