Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(API): make sure unsubscribe is invoked when subscription cancelled #3619

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
self.state.value == .connected
}

internal var numberOfSubscriptions: Int {
self.subscriptions.count
}

/**
Creates a new AppSyncRealTimeClient with endpoint, requestInterceptor and webSocketClient.
- Parameters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner,
self.apiAuthProviderFactory = apiAuthProviderFactory
}

/// When the top-level AmplifyThrowingSequence is canceled, this cancel method is invoked.
/// In this situation, we need to send the disconnected event because
/// the top-level AmplifyThrowingSequence is terminated immediately upon cancellation.
public func cancel() {
self.send(GraphQLSubscriptionEvent<R>.connection(.disconnected))
Task {
Expand Down Expand Up @@ -210,12 +213,7 @@ final public class AWSGraphQLSubscriptionOperation<R: Decodable>: GraphQLSubscri

override public func cancel() {
super.cancel()

Task { [weak self] in
guard let self else {
return
}

Task {
guard let appSyncRealTimeClient = self.appSyncRealTimeClient else {
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,65 @@ class GraphQLModelBasedTests: XCTestCase {
await fulfillment(of: [progressInvoked], timeout: TestCommonConstants.networkTimeout)
}


/// Given: Several subscriptions with Amplify API plugin
/// When: Cancel subscriptions
/// Then: AppSync real time client automatically unsubscribe and remove the subscription
func testCancelledSubscription_automaticallyUnsubscribeAndRemoved() async throws {
let numberOfSubscription = 5
let allSubscribedExpectation = expectation(description: "All subscriptions are subscribed")
allSubscribedExpectation.expectedFulfillmentCount = numberOfSubscription

let subscriptions = (0..<5).map { _ in
Amplify.API.subscribe(request: .subscription(of: Comment.self, type: .onCreate))
}
subscriptions.forEach { subscription in
Task {
do {
for try await subscriptionEvent in subscription {
switch subscriptionEvent {
case .connection(let state):
switch state {
case .connecting:
break
case .connected:
allSubscribedExpectation.fulfill()
case .disconnected:
break
}
case .data(let result):
switch result {
case .success: break
case .failure(let error):
XCTFail("\(error)")
}
}
}
} catch {
XCTFail("Unexpected subscription failure")
}
}
}

await fulfillment(of: [allSubscribedExpectation], timeout: 3)
if let appSyncRealTimeClientFactory =
getUnderlyingAPIPlugin()?.appSyncRealTimeClientFactory as? AppSyncRealTimeClientFactory,
let appSyncRealTimeClient =
await appSyncRealTimeClientFactory.apiToClientCache.values.first as? AppSyncRealTimeClient
{
var appSyncSubscriptions = await appSyncRealTimeClient.numberOfSubscriptions
XCTAssertEqual(appSyncSubscriptions, numberOfSubscription)

subscriptions.forEach { $0.cancel() }
try await Task.sleep(seconds: 2)
appSyncSubscriptions = await appSyncRealTimeClient.numberOfSubscriptions
XCTAssertEqual(appSyncSubscriptions, 0)

} else {
XCTFail("There should be at least one AppSyncRealTimeClient instance")
}
}

// MARK: Helpers

func createPost(id: String, title: String) async throws -> Post? {
Expand Down Expand Up @@ -499,4 +558,8 @@ class GraphQLModelBasedTests: XCTestCase {
throw error
}
}

func getUnderlyingAPIPlugin() -> AWSAPIPlugin? {
return Amplify.API.plugins["awsAPIPlugin"] as? AWSAPIPlugin
}
}
Loading