diff --git a/ApolloWebSocket.xcodeproj/project.pbxproj b/ApolloWebSocket.xcodeproj/project.pbxproj index 89e93eebcb..7a48b8844a 100644 --- a/ApolloWebSocket.xcodeproj/project.pbxproj +++ b/ApolloWebSocket.xcodeproj/project.pbxproj @@ -23,6 +23,7 @@ 9F28B6D520720F2F00144A00 /* Apollo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D420720F2F00144A00 /* Apollo.framework */; }; 9F28B6D920720FD200144A00 /* ApolloTestSupport.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */; }; 9F28B6DB2072101200144A00 /* StarWarsAPI.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */; }; + D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */ = {isa = PBXBuildFile; fileRef = D1ACF61B23715AF30042E200 /* Atomic.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -73,6 +74,7 @@ 9F28B6D420720F2F00144A00 /* Apollo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = Apollo.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 9F28B6D820720FD100144A00 /* ApolloTestSupport.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = ApolloTestSupport.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 9F28B6DA2072101200144A00 /* StarWarsAPI.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = StarWarsAPI.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + D1ACF61B23715AF30042E200 /* Atomic.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Atomic.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -141,6 +143,7 @@ 9B1CCDE223611606007C9032 /* WebSocketTask.swift */, 7270746B206D111A00C131F6 /* WebSocketTransport.swift */, 7270746C206D111A00C131F6 /* Info.plist */, + D1ACF61923715AF30042E200 /* Utilities */, ); name = ApolloWebSocket; path = Sources/ApolloWebSocket; @@ -181,6 +184,14 @@ name = Products; sourceTree = ""; }; + D1ACF61923715AF30042E200 /* Utilities */ = { + isa = PBXGroup; + children = ( + D1ACF61B23715AF30042E200 /* Atomic.swift */, + ); + path = Utilities; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ @@ -310,6 +321,7 @@ 9B1CCDDF236110C3007C9032 /* WebSocketError.swift in Sources */, 7270746D206D111A00C131F6 /* SplitNetworkTransport.swift in Sources */, 9B1CCDE323611606007C9032 /* WebSocketTask.swift in Sources */, + D1ACF61D23715AF30042E200 /* Atomic.swift in Sources */, 7270746E206D111A00C131F6 /* WebSocketTransport.swift in Sources */, 9B1CCDE123611580007C9032 /* OperationMessage.swift in Sources */, 9B1CCDDB23610CDC007C9032 /* ApolloWebSocket.swift in Sources */, diff --git a/Sources/ApolloWebSocket/ApolloWebSocket.swift b/Sources/ApolloWebSocket/ApolloWebSocket.swift index ad79679c1b..29a9fab8bc 100644 --- a/Sources/ApolloWebSocket/ApolloWebSocket.swift +++ b/Sources/ApolloWebSocket/ApolloWebSocket.swift @@ -14,6 +14,9 @@ public protocol ApolloWebSocketClient: WebSocketClient { /// The URLRequest used on connection. var request: URLRequest { get set } + + /// Queue where the callbacks are executed + var callbackQueue: DispatchQueue { get set } } // MARK: - WebSocket diff --git a/Sources/ApolloWebSocket/Utilities/Atomic.swift b/Sources/ApolloWebSocket/Utilities/Atomic.swift new file mode 100644 index 0000000000..4c1c56f1fa --- /dev/null +++ b/Sources/ApolloWebSocket/Utilities/Atomic.swift @@ -0,0 +1,36 @@ +import Foundation + +class Atomic { + private let lock = NSLock() + private var _value: T + + init(_ value: T) { + _value = value + } + + var value: T { + get { + lock.lock() + defer { lock.unlock() } + + return _value + } + set { + lock.lock() + defer { lock.unlock() } + + _value = newValue + } + } +} + +extension Atomic where T == Int { + + func increment() -> T { + lock.lock() + defer { lock.unlock() } + + _value += 1 + return _value + } +} diff --git a/Sources/ApolloWebSocket/WebSocketTransport.swift b/Sources/ApolloWebSocket/WebSocketTransport.swift index 7bde782dde..73702077f5 100644 --- a/Sources/ApolloWebSocket/WebSocketTransport.swift +++ b/Sources/ApolloWebSocket/WebSocketTransport.swift @@ -25,9 +25,9 @@ public class WebSocketTransport { public static var provider: ApolloWebSocketClient.Type = ApolloWebSocket.self public weak var delegate: WebSocketTransportDelegate? - var reconnect = false + let reconnect: Atomic = Atomic(false) var websocket: ApolloWebSocketClient - var error: Error? = nil + let error: Atomic = Atomic(nil) let serializationFormat = JSONSerializationFormat.self private let requestCreator: RequestCreator @@ -40,10 +40,11 @@ public class WebSocketTransport { private var subscribers = [String: (Result) -> Void]() private var subscriptions : [String: String] = [:] + private let processingQueue = DispatchQueue(label: "com.apollographql.WebSocketTransport") private let sendOperationIdentifiers: Bool private let reconnectionInterval: TimeInterval - fileprivate var sequenceNumber = 0 + fileprivate let sequenceNumberCounter = Atomic(0) fileprivate var reconnected = false /// NOTE: Setting this won't override immediately if the socket is still connected, only on reconnection. @@ -87,6 +88,7 @@ public class WebSocketTransport { self.websocket.request.setValue(self.clientVersion, forHTTPHeaderField: WebSocketTransport.headerFieldNameClientVersion) self.websocket.delegate = self self.websocket.connect() + self.websocket.callbackQueue = processingQueue } public func isConnected() -> Bool { @@ -174,7 +176,7 @@ public class WebSocketTransport { } public func initServer(reconnect: Bool = true) { - self.reconnect = reconnect + self.reconnect.value = reconnect self.acked = false if let str = OperationMessage(payload: self.connectingPayload, type: .connectionInit).rawMessage { @@ -184,12 +186,17 @@ public class WebSocketTransport { } public func closeConnection() { - self.reconnect = false - if let str = OperationMessage(type: .connectionTerminate).rawMessage { - write(str) + self.reconnect.value = false + + let str = OperationMessage(type: .connectionTerminate).rawMessage + processingQueue.async { + if let str = str { + self.write(str) + } + + self.queue.removeAll() + self.subscriptions.removeAll() } - self.queue.removeAll() - self.subscriptions.removeAll() } private func write(_ str: String, force forced: Bool = false, id: Int? = nil) { @@ -213,35 +220,36 @@ public class WebSocketTransport { websocket.delegate = nil } - private func nextSequenceNumber() -> Int { - sequenceNumber += 1 - return sequenceNumber - } - func sendHelper(operation: Operation, resultHandler: @escaping (_ result: Result) -> Void) -> String? { let body = requestCreator.requestBody(for: operation, sendOperationIdentifiers: self.sendOperationIdentifiers) - let sequenceNumber = "\(nextSequenceNumber())" + let sequenceNumber = "\(sequenceNumberCounter.increment())" guard let message = OperationMessage(payload: body, id: sequenceNumber).rawMessage else { return nil } - - write(message) + + processingQueue.async { + self.write(message) - subscribers[sequenceNumber] = resultHandler - if operation.operationType == .subscription { - subscriptions[sequenceNumber] = message + self.subscribers[sequenceNumber] = resultHandler + if operation.operationType == .subscription { + self.subscriptions[sequenceNumber] = message + } } return sequenceNumber } public func unsubscribe(_ subscriptionId: String) { - if let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage { - write(str) + let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage + + processingQueue.async { + if let str = str { + self.write(str) + } + self.subscribers.removeValue(forKey: subscriptionId) + self.subscriptions.removeValue(forKey: subscriptionId) } - subscribers.removeValue(forKey: subscriptionId) - subscriptions.removeValue(forKey: subscriptionId) } } @@ -249,7 +257,7 @@ public class WebSocketTransport { extension WebSocketTransport: NetworkTransport { public func send(operation: Operation, completionHandler: @escaping (_ result: Result,Error>) -> Void) -> Cancellable { - if let error = self.error { + if let error = self.error.value { completionHandler(.failure(error)) return EmptyCancellable() } @@ -271,7 +279,7 @@ extension WebSocketTransport: NetworkTransport { extension WebSocketTransport: WebSocketDelegate { public func websocketDidConnect(socket: WebSocketClient) { - self.error = nil + self.error.value = nil initServer() if reconnected { self.delegate?.webSocketTransportDidReconnect(self) @@ -290,16 +298,16 @@ extension WebSocketTransport: WebSocketDelegate { public func websocketDidDisconnect(socket: WebSocketClient, error: Error?) { // report any error to all subscribers if let error = error { - self.error = WebSocketError(payload: nil, error: error, kind: .networkError) + self.error.value = WebSocketError(payload: nil, error: error, kind: .networkError) self.notifyErrorAllHandlers(error) } else { - self.error = nil + self.error.value = nil } - self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error) + self.delegate?.webSocketTransport(self, didDisconnectWithError: self.error.value) acked = false // need new connect and ack before sending - if reconnect { + if reconnect.value { DispatchQueue.main.asyncAfter(deadline: .now() + reconnectionInterval) { self.websocket.connect() } diff --git a/Tests/ApolloWebsocketTests/MockWebSocket.swift b/Tests/ApolloWebsocketTests/MockWebSocket.swift index 8d965f3c5b..ae0208f7cd 100644 --- a/Tests/ApolloWebsocketTests/MockWebSocket.swift +++ b/Tests/ApolloWebsocketTests/MockWebSocket.swift @@ -2,6 +2,9 @@ import Starscream @testable import ApolloWebSocket class MockWebSocket: ApolloWebSocketClient { + + var callbackQueue: DispatchQueue = DispatchQueue.main + var pongDelegate: WebSocketPongDelegate? var request: URLRequest @@ -15,8 +18,16 @@ class MockWebSocket: ApolloWebSocketClient { self.request = URLRequest(url: URL(string: "http://localhost:8080")!) } + open func reportDidConnect() { + callbackQueue.async { + self.delegate?.websocketDidConnect(socket: self) + } + } + open func write(string: String, completion: (() -> ())?) { - delegate?.websocketDidReceiveMessage(socket: self, text: string) + callbackQueue.async { + self.delegate?.websocketDidReceiveMessage(socket: self, text: string) + } } open func write(data: Data, completion: (() -> ())?) { diff --git a/Tests/ApolloWebsocketTests/StarWarsSubscriptionTests.swift b/Tests/ApolloWebsocketTests/StarWarsSubscriptionTests.swift index d2c331e816..94c456aad3 100644 --- a/Tests/ApolloWebsocketTests/StarWarsSubscriptionTests.swift +++ b/Tests/ApolloWebsocketTests/StarWarsSubscriptionTests.swift @@ -6,14 +6,17 @@ import StarWarsAPI class StarWarsSubscriptionTests: XCTestCase { let SERVER: String = "ws://localhost:8080/websocket" + let concurrentQueue = DispatchQueue(label: "com.apollographql.testing", attributes: .concurrent) var client: ApolloClient! + var webSocketTransport: WebSocketTransport! override func setUp() { super.setUp() - let networkTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!)) - client = ApolloClient(networkTransport: networkTransport) + WebSocketTransport.provider = ApolloWebSocket.self + webSocketTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!)) + client = ApolloClient(networkTransport: webSocketTransport) } // MARK: Subscriptions @@ -252,4 +255,121 @@ class StarWarsSubscriptionTests: XCTestCase { subJedi.cancel() subNewHope.cancel() } + + // MARK: Data races tests + + func testConcurrentSubscribing() { + let firstSubscription = ReviewAddedSubscription(episode: .empire) + let secondSubscription = ReviewAddedSubscription(episode: .empire) + + let expectation = self.expectation(description: "Subscribers connected and received events") + expectation.expectedFulfillmentCount = 2 + + var sub1: Cancellable? + var sub2: Cancellable? + + concurrentQueue.async { + sub1 = self.client.subscribe(subscription: firstSubscription) { _ in + expectation.fulfill() + } + } + + concurrentQueue.async { + sub2 = self.client.subscribe(subscription: secondSubscription) { _ in + expectation.fulfill() + } + } + + // dispatched with a barrier flag to make sure + // this is performed after subscription calls + concurrentQueue.sync(flags: .barrier) { + // dispatched on the processing queue to make sure + // this is performed after subscribers are processed + self.webSocketTransport.websocket.callbackQueue.async { + _ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!"))) + } + } + + waitForExpectations(timeout: 10, handler: nil) + sub1?.cancel() + sub2?.cancel() + } + + func testConcurrentSubscriptionCancellations() { + let firstSubscription = ReviewAddedSubscription(episode: .empire) + let secondSubscription = ReviewAddedSubscription(episode: .empire) + + let expectation = self.expectation(description: "Subscriptions cancelled") + expectation.expectedFulfillmentCount = 2 + let invertedExpectation = self.expectation(description: "Subscription received callback - expecting timeout") + invertedExpectation.isInverted = true + + let sub1 = client.subscribe(subscription: firstSubscription) { _ in + invertedExpectation.fulfill() + } + let sub2 = client.subscribe(subscription: secondSubscription) { _ in + invertedExpectation.fulfill() + } + + concurrentQueue.async { + sub1.cancel() + expectation.fulfill() + } + concurrentQueue.async { + sub2.cancel() + expectation.fulfill() + } + + wait(for: [expectation], timeout: 10) + + _ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!"))) + + wait(for: [invertedExpectation], timeout: 2) + } + + func testConcurrentSubscriptionAndConnectionClose() { + let empireReviewSubscription = ReviewAddedSubscription(episode: .empire) + let expectation = self.expectation(description: "Connection closed") + let invertedExpectation = self.expectation(description: "Subscription received callback - expecting timeout") + invertedExpectation.isInverted = true + + let sub = self.client.subscribe(subscription: empireReviewSubscription) { _ in + invertedExpectation.fulfill() + } + + concurrentQueue.async { + sub.cancel() + } + concurrentQueue.async { + self.webSocketTransport.closeConnection() + expectation.fulfill() + } + + wait(for: [expectation], timeout: 10) + + _ = self.client.perform(mutation: CreateReviewForEpisodeMutation(episode: .empire, review: ReviewInput(stars: 5, commentary: "The greatest movie ever!"))) + + wait(for: [invertedExpectation], timeout: 2) + } + + func testConcurrentConnectAndCloseConnection() { + WebSocketTransport.provider = MockWebSocket.self + let webSocketTransport = WebSocketTransport(request: URLRequest(url: URL(string: SERVER)!)) + let expectation = self.expectation(description: "Connection closed") + expectation.expectedFulfillmentCount = 2 + + concurrentQueue.async { + if let websocket = webSocketTransport.websocket as? MockWebSocket { + websocket.reportDidConnect() + expectation.fulfill() + } + } + + concurrentQueue.async { + webSocketTransport.closeConnection() + expectation.fulfill() + } + + waitForExpectations(timeout: 10, handler: nil) + } }