From 81f06d159125b80638ed409740457a2576e6d8fa Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Nov 2023 15:39:23 +0000 Subject: [PATCH 01/12] Update existing test cases --- tests-integration/poller_test.go | 68 +++++++++++++++----------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/tests-integration/poller_test.go b/tests-integration/poller_test.go index ed66ed26..720bd393 100644 --- a/tests-integration/poller_test.go +++ b/tests-integration/poller_test.go @@ -119,18 +119,14 @@ func TestSecondPollerFiltersToDevice(t *testing.T) { m.MatchResponse(t, res, m.MatchToDeviceMessages([]json.RawMessage{wantMsg})) } -// Test that the poller makes a best-effort attempt to integrate state seen in a -// v2 sync state block. Our strategy for doing so is to prepend any unknown state events -// to the start of the v2 sync response's timeline, which should then be visible to -// sync v3 clients as ordinary state events in the room timeline. func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) { - // FIXME: this should resolve once we update downstream caches - t.Skip("We will never see the name/PL event in the timeline with the new code due to those events being part of the state block.") pqString := testutils.PrepareDBConnectionString() v2 := runTestV2Server(t) v3 := runTestServer(t, v2, pqString) defer v2.close() defer v3.close() + + t.Log("Alice creates a room.") v2.addAccount(t, alice, aliceToken) const roomID = "!unimportant" v2.queueResponse(aliceToken, sync2.SyncResponse{ @@ -141,18 +137,21 @@ func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) { }), }, }) - res := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + t.Log("Alice sliding syncs, explicitly requesting power levels.") + aliceReq := sync3.Request{ Lists: map[string]sync3.RequestList{ "a": { Ranges: [][2]int64{{0, 20}}, RoomSubscription: sync3.RoomSubscription{ TimelineLimit: 10, + RequiredState: [][2]string{{"m.room.power_levels", ""}}, }, }, }, - }) + } + res := v3.mustDoV3Request(t, aliceToken, aliceReq) - t.Log("The poller receives a gappy incremental sync response with a state block. The power levels and room name have changed.") + t.Log("Alice's poller receives a gappy poll with a state block. 
The power levels and room name have changed.") nameEvent := testutils.NewStateEvent( t, "m.room.name", @@ -187,37 +186,26 @@ func TestPollerHandlesUnknownStateEventsOnIncrementalSync(t *testing.T) { }, }, }) + v2.waitUntilEmpty(t, aliceToken) - res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{}) - m.MatchResponse( - t, - res, - m.MatchRoomSubscription( - roomID, - func(r sync3.Room) error { - // syncv2 doesn't assign any meaning to the order of events in a state - // block, so check for both possibilities - nameFirst := m.MatchRoomTimeline([]json.RawMessage{nameEvent, powerLevelsEvent, messageEvent}) - powerLevelsFirst := m.MatchRoomTimeline([]json.RawMessage{powerLevelsEvent, nameEvent, messageEvent}) - if nameFirst(r) != nil && powerLevelsFirst(r) != nil { - return fmt.Errorf("did not see state before message") - } - return nil - }, - m.MatchRoomName("banana"), - ), - ) + t.Log("Alice incremental sliding syncs.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{}) + t.Log("The server should have closed the long-polling session.") + assertUnknownPos(t, respBytes, statusCode) + + t.Log("Alice sliding syncs from scratch.") + res = v3.mustDoV3Request(t, aliceToken, aliceReq) + t.Log("Alice sees the new room name and power levels.") + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, + m.MatchRoomRequiredState([]json.RawMessage{powerLevelsEvent}), + m.MatchRoomName("banana"), + )) } // Similar to TestPollerHandlesUnknownStateEventsOnIncrementalSync. Here we are testing // that if Alice's poller sees Bob leave in a state block, the events seen in that // timeline are not visible to Bob. func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) { - // the room state should update to make bob no longer be a member, which should update downstream caches - // DO WE SEND THESE GAPPY STATES TO THE CLIENT? It's NOT part of the timeline, but we need to let the client - // know somehow? I think the best case here would be to invalidate that _room_ (if that were possible in the API) - // to force the client to resync the state. 
- t.Skip("figure out what the valid thing to do here is") pqString := testutils.PrepareDBConnectionString() v2 := runTestV2Server(t) v3 := runTestServer(t, v2, pqString) @@ -312,15 +300,21 @@ func TestPollerUpdatesRoomMemberTrackerOnGappySyncStateBlock(t *testing.T) { }, }, }) + v2.waitUntilEmpty(t, aliceToken) t.Log("Bob makes an incremental sliding sync request.") - bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{}) - t.Log("He should see his leave event in the room timeline.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), bobToken, bobRes.Pos, sync3.Request{}) + assertUnknownPos(t, respBytes, statusCode) + + t.Log("Bob makes a new sliding sync session.") + bobRes = v3.mustDoV3Request(t, bobToken, syncRequest) + + t.Log("He shouldn't see any evidence of the room.") m.MatchResponse( t, bobRes, - m.MatchList("a", m.MatchV3Count(1)), - m.MatchRoomSubscription(roomID, m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bobLeave})), + m.MatchList("a", m.MatchV3Count(0)), + m.MatchRoomSubscriptionsStrict(nil), ) } From e872b18fe61fa1f1ba4cc3d649206b6fa344e35f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Nov 2023 15:39:47 +0000 Subject: [PATCH 02/12] New test cases --- tests-integration/poller_test.go | 665 +++++++++++++++++++++++++++++++ 1 file changed, 665 insertions(+) diff --git a/tests-integration/poller_test.go b/tests-integration/poller_test.go index 720bd393..5ce3b4dd 100644 --- a/tests-integration/poller_test.go +++ b/tests-integration/poller_test.go @@ -594,3 +594,668 @@ func TestTimelineStopsLoadingWhenMissingPrevious(t *testing.T) { m.MatchRoomPrevBatch("dummyPrevBatch"), )) } + +// The "prepend state events" mechanism added in +// https://github.com/matrix-org/sliding-sync/pull/71 ensured that the proxy +// communicated state events in "gappy syncs" to users. But it did so via Accumulate, +// which made one snapshot for each state event. This was not an accurate model of the +// room's history (the state block comes in no particular order) and had awful +// performance for large gappy states. +// +// We now want to handle these in Initialise, making a single snapshot for the state +// block. This test ensures that is the case. The logic is very similar to the e2e test +// TestGappyState. 
+func TestGappyStateDoesNotAccumulateTheStateBlock(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + v2 := runTestV2Server(t) + defer v2.close() + v3 := runTestServer(t, v2, pqString) + defer v3.close() + + v2.addAccount(t, alice, aliceToken) + v2.addAccount(t, bob, bobToken) + + t.Log("Alice creates a room, sets its name and sends a message.") + const roomID = "!unimportant" + name1 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{ + "name": "wonderland", + }) + msg1 := testutils.NewMessageEvent(t, alice, "0118 999 881 999 119 7253") + + joinTimeline := v2JoinTimeline(roomEvents{ + roomID: roomID, + events: append( + createRoomState(t, alice, time.Now()), + name1, + msg1, + ), + }) + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: joinTimeline, + }, + }) + + t.Log("Alice sliding syncs with a huge timeline limit, subscribing to the room she just created.") + aliceReq := sync3.Request{ + RoomSubscriptions: map[string]sync3.RoomSubscription{ + roomID: {TimelineLimit: 100}, + }, + } + res := v3.mustDoV3Request(t, aliceToken, aliceReq) + + t.Log("Alice sees the room with the expected name, with the name event and message at the end of the timeline.") + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, + m.MatchRoomName("wonderland"), + m.MatchRoomTimelineMostRecent(2, []json.RawMessage{name1, msg1}), + )) + + t.Log("Alice's poller receives a gappy sync, including a room name change, bob joining, and two messages.") + stateBlock := make([]json.RawMessage, 0) + for i := 0; i < 10; i++ { + statePiece := testutils.NewStateEvent(t, "com.example.custom", fmt.Sprintf("%d", i), alice, map[string]any{}) + stateBlock = append(stateBlock, statePiece) + } + name2 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{ + "name": "not wonderland", + }) + bobJoin := testutils.NewJoinEvent(t, bob) + stateBlock = append(stateBlock, name2, bobJoin) + + msg2 := testutils.NewMessageEvent(t, alice, "Good morning!") + msg3 := testutils.NewMessageEvent(t, alice, "That's a nice tnetennba.") + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomID: { + State: sync2.EventsResponse{ + Events: stateBlock, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{msg2, msg3}, + Limited: true, + PrevBatch: "dummyPrevBatch", + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, aliceToken) + + t.Log("Alice syncs. The server should close her long-polling session.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, res.Pos, sync3.Request{}) + assertUnknownPos(t, respBytes, statusCode) + + t.Log("Alice sliding syncs from scratch. She should see the two most recent message in the timeline only. The room name should have changed too.") + res = v3.mustDoV3Request(t, aliceToken, aliceReq) + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, + m.MatchRoomName("not wonderland"), + // In particular, we shouldn't see state here because it's not part of the timeline. + // Nor should we see msg1, as that comes before a gap. + m.MatchRoomTimeline([]json.RawMessage{msg2, msg3}), + )) +} + +// Right, this has turned out to be very involved. This test has three varying +// parameters: +// - Bert's initial membership (in 3 below), +// - his final membership in (5), and +// - whether his sync in (6) is initial or long-polling ("live"). +// +// The test: +// 1. Registers two users Ana and Bert. +// 2. 
Has Ana create a public room. +// 3. Sets an initial membership for Bert in that room. +// 4. Sliding syncs for Bert, if he will live-sync in (6) below. +// 5. Gives Ana's poller a gappy poll in which Bert's membership changes. +// 6. Has Bert do a sliding sync. +// 7. Ana invites Bert to a DM. +// +// We perform the following assertions: +// - After (3), Ana sees her membership, Bert's initial membership, appropriate +// join and invite counts, and an appropriate timeline. +// - If applicable: after (4), Bert sees his initial membership. +// - After (5), Ana's connection is closed. When opening a new one, she sees her +// membership, Bert's new membership, and the post-gap timeline. +// - After (6), Bert's connection is closed if he was expecting a live update. +// - After (6), Bert sees his new membership (if there is anything to see). +// - After (7), Bert sees the DM invite. +// +// Remarks: +// - Use a per-test Ana and Bert here so we don't clash with the global constants +// alice and bob. +// - We're feeding all this information in via Ana's poller to check that stuff +// propagates from her poller to Bert's client. However, when Bob's membership is +// "invite" we need to directly send the invite to his poller. +// - Step (7) serves as a sentinel to prove that the proxy has processed (5) in the +// case where there is nothing for Bert to see in (6), e.g. a preemptive ban or +// an unban during the gap. +// - Testing all the membership transitions is likely overkill. But it was useful +// for finding edge cases in the proxy's assumptions at first, before we decided to +// nuke conns and userCaches and start from scratch. +func TestClientsSeeMembershipTransitionsInGappyPolls(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + v2 := runTestV2Server(t) + // TODO remove this? Otherwise running tests is sloooooow + v2.timeToWaitForV2Response /= 20 + defer v2.close() + v3 := runTestServer(t, v2, pqString) + defer v3.close() + + type testcase struct { + // Inputs + beforeMembership string + afterMembership string + viaLiveUpdate bool + // Scratch space + id string + ana string + anaToken string + bert string + bertToken string + publicRoomID string // room that will receive gappy state + dmRoomID string // DM between ana and bert, used to send a sentinel message + } + + var tcs []testcase + + transitions := map[string][]string{ + // before: {possible after} + // https://spec.matrix.org/v1.8/client-server-api/#room-membership for the list of allowed transitions + "none": {"ban", "invite", "join", "leave"}, + "invite": {"ban", "join", "leave"}, + // Note: can also join->join here e.g. 
for displayname change, but will do that in a separate test + "join": {"ban", "leave"}, + "leave": {"ban", "invite", "join"}, + "ban": {"leave"}, + } + for before, afterOptions := range transitions { + for _, after := range afterOptions { + for _, live := range []bool{true, false} { + idStr := fmt.Sprintf("%s-%s", before, after) + if live { + idStr += "-live" + } + + tc := testcase{ + beforeMembership: before, + afterMembership: after, + viaLiveUpdate: live, + id: idStr, + publicRoomID: fmt.Sprintf("!%s-public", idStr), + dmRoomID: fmt.Sprintf("!%s-dm", idStr), + // Using ana and bert to stop myself from pulling in package-level constants alice and bob + ana: fmt.Sprintf("@ana-%s:localhost", idStr), + bert: fmt.Sprintf("@bert-%s:localhost", idStr), + } + tc.anaToken = tc.ana + "_token" + tc.bertToken = tc.bert + "_token" + tcs = append(tcs, tc) + } + } + } + + ssRequest := sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + Ranges: sync3.SliceRanges{{0, 10}}, + RoomSubscription: sync3.RoomSubscription{ + RequiredState: [][2]string{{"m.room.member", "*"}}, + TimelineLimit: 20, + }, + }, + }, + } + + setup := func(t *testing.T, tc testcase) (publicEvents []json.RawMessage, anaMembership json.RawMessage, anaRes *sync3.Response) { + // 1. Register two users Ana and Bert. + v2.addAccount(t, tc.ana, tc.anaToken) + v2.addAccount(t, tc.bert, tc.bertToken) + + // 2. Have Ana create a public room. + t.Log("Ana creates a public room.") + publicEvents = createRoomState(t, tc.ana, time.Now()) + for _, ev := range publicEvents { + parsed := gjson.ParseBytes(ev) + if parsed.Get("type").Str == "m.room.member" && parsed.Get("state_key").Str == tc.ana { + anaMembership = ev + break + } + } + + // 3. Set an initial membership for Bert. + var wantJoinCount int + var wantInviteCount int + var bertMembership json.RawMessage + + switch tc.beforeMembership { + case "none": + t.Log("Bert has no membership in the room.") + wantJoinCount = 1 + wantInviteCount = 0 + case "invite": + t.Log("Bert is invited.") + bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"}) + wantJoinCount = 1 + wantInviteCount = 1 + case "join": + t.Log("Bert joins the room.") + bertMembership = testutils.NewJoinEvent(t, tc.bert) + wantJoinCount = 2 + wantInviteCount = 0 + case "leave": + t.Log("Bert is pre-emptively kicked.") + bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "leave"}) + wantJoinCount = 1 + wantInviteCount = 0 + case "ban": + t.Log("Bert is banned.") + bertMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "ban"}) + wantJoinCount = 1 + wantInviteCount = 0 + default: + panic(fmt.Errorf("unknown beforeMembership %s", tc.beforeMembership)) + } + if len(bertMembership) > 0 { + publicEvents = append(publicEvents, bertMembership) + } + + t.Log("Ana's poller sees the public room for the first time.") + v2.queueResponse(tc.anaToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + tc.publicRoomID: { + Timeline: sync2.TimelineResponse{ + Events: publicEvents, + PrevBatch: "anaPublicPrevBatch1", + }, + }, + }, + }, + NextBatch: "anaSync1", + }) + + t.Log("Ana sliding syncs, requesting all room members.") + anaRes = v3.mustDoV3Request(t, tc.anaToken, ssRequest) + t.Log("She sees herself joined to both rooms, with appropriate timelines and counts.") + // Note: we only expect timeline[1:] here, not the 
create event. See + // https://github.com/matrix-org/sliding-sync/issues/343 + expectedMembers := []json.RawMessage{anaMembership} + if len(bertMembership) > 0 { + expectedMembers = append(expectedMembers, bertMembership) + } + m.MatchResponse(t, anaRes, + m.MatchRoomSubscription(tc.publicRoomID, + m.MatchRoomTimeline(publicEvents[1:]), + m.MatchRoomRequiredState(expectedMembers), + m.MatchJoinCount(wantJoinCount), + m.MatchInviteCount(wantInviteCount), + ), + ) + + return + } + + gappyPoll := func(t *testing.T, tc testcase, anaMembership json.RawMessage, anaRes *sync3.Response) (newMembership json.RawMessage, publicTimeline []json.RawMessage) { + t.Logf("Ana's poller gets a gappy sync response for the public room. Bert's membership is now %s, and Ana has sent 10 messages.", tc.afterMembership) + publicTimeline = make([]json.RawMessage, 10) + for i := range publicTimeline { + publicTimeline[i] = testutils.NewMessageEvent(t, tc.ana, fmt.Sprintf("hello %d", i)) + } + + switch tc.afterMembership { + case "invite": + newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"}) + case "join": + newMembership = testutils.NewJoinEvent(t, tc.bert) + case "leave": + newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "leave"}) + case "ban": + newMembership = testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "ban"}) + default: + panic(fmt.Errorf("unknown afterMembership %s", tc.afterMembership)) + } + + v2.queueResponse(tc.anaToken, sync2.SyncResponse{ + NextBatch: "ana2", + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + tc.publicRoomID: { + State: sync2.EventsResponse{ + Events: []json.RawMessage{newMembership}, + }, + Timeline: sync2.TimelineResponse{ + Events: publicTimeline, + Limited: true, + PrevBatch: "anaPublicPrevBatch2", + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, tc.anaToken) + + if tc.afterMembership == "invite" { + t.Log("Bert's poller sees his invite.") + v2.queueResponse(tc.bertToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Invite: map[string]sync2.SyncV2InviteResponse{ + tc.publicRoomID: { + InviteState: sync2.EventsResponse{ + // TODO: this really ought to be stripped state events + Events: []json.RawMessage{anaMembership, newMembership}, + }, + }, + }}, + NextBatch: tc.bert + "_invite", + }) + } + + t.Log("Ana syncs.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), tc.anaToken, anaRes.Pos, sync3.Request{}) + + t.Log("Her long-polling session has been closed by the server.") + assertUnknownPos(t, respBytes, statusCode) + + t.Log("Ana syncs again from scratch.") + anaRes = v3.mustDoV3Request(t, tc.anaToken, ssRequest) + + t.Log("She sees both her and Bob's membership, and the timeline from the gappy poll.") + // Note: we don't expect to see the pre-gap timeline, here because we stop at + // the first gap we see in the timeline. + m.MatchResponse(t, anaRes, m.MatchRoomSubscription(tc.publicRoomID, + m.MatchRoomRequiredState([]json.RawMessage{anaMembership, newMembership}), + m.MatchRoomTimeline(publicTimeline), + )) + return + } + + for _, tc := range tcs { + t.Run(tc.id, func(t *testing.T) { + // 1--3: Register users, create public room, set Bert's membership. + publicEvents, anaMembership, anaRes := setup(t, tc) + defer func() { + // Cleanup these users once we're done with them. This helps stop log spam when debugging. 
+ v2.invalidateTokenImmediately(tc.anaToken) + v2.invalidateTokenImmediately(tc.bertToken) + }() + + // Ensure the proxy considers Bert to already be polling. In particular, if + // Bert is initially invited, make sure his poller sees the invite. + if tc.beforeMembership == "invite" { + t.Log("Bert's poller sees his invite.") + v2.queueResponse(tc.bertToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Invite: map[string]sync2.SyncV2InviteResponse{ + tc.publicRoomID: { + InviteState: sync2.EventsResponse{ + // TODO: this really ought to be stripped state events + Events: publicEvents, + }, + }, + }}, + NextBatch: tc.bert + "_invite", + }) + } else { + t.Log("Queue up an empty poller response for Bert.") + v2.queueResponse(tc.bertToken, sync2.SyncResponse{ + NextBatch: tc.bert + "_empty_sync", + }) + } + t.Log("Bert makes a dummy request with a different connection ID, to ensure his poller has started.") + v3.mustDoV3Request(t, tc.bertToken, sync3.Request{ + ConnID: "bert-dummy-conn", + }) + + var bertRes *sync3.Response + // 4: sliding sync for Bert, if he will live-sync in (6) below. + if tc.viaLiveUpdate { + t.Log("Bert sliding syncs.") + bertRes = v3.mustDoV3Request(t, tc.bertToken, ssRequest) + + // Bert will see the entire history of these rooms, so there shouldn't be any prev batch tokens. + expectedSubscriptions := map[string][]m.RoomMatcher{} + switch tc.beforeMembership { + case "invite": + t.Log("Bert sees his invite.") + expectedSubscriptions[tc.publicRoomID] = []m.RoomMatcher{ + m.MatchRoomHasInviteState(), + m.MatchInviteCount(1), + m.MatchJoinCount(1), + m.MatchRoomPrevBatch(""), + } + case "join": + t.Log("Bert sees his join.") + expectedSubscriptions[tc.publicRoomID] = []m.RoomMatcher{ + m.MatchRoomLacksInviteState(), + m.MatchInviteCount(0), + m.MatchJoinCount(2), + m.MatchRoomPrevBatch(""), + } + case "none": + fallthrough + case "leave": + fallthrough + case "ban": + t.Log("Bert does not see the room.") + default: + panic(fmt.Errorf("unknown beforeMembership %s", tc.beforeMembership)) + } + m.MatchResponse(t, bertRes, m.MatchRoomSubscriptionsStrict(expectedSubscriptions)) + } + + // 5: Ana receives a gappy poll, plus a sentinel in her DM with Bert. + newMembership, publicTimeline := gappyPoll(t, tc, anaMembership, anaRes) + + // 6: Bert sliding syncs. + if tc.viaLiveUpdate { + wasInvolvedInRoom := tc.beforeMembership == "join" || tc.beforeMembership == "invite" + if wasInvolvedInRoom { + t.Log("Bert makes an incremental sliding sync.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), tc.bertToken, bertRes.Pos, ssRequest) + assertUnknownPos(t, respBytes, statusCode) + } + } else { + t.Log("Queue up an empty poller response for Bert. so the proxy will consider him to be polling.") + v2.queueResponse(tc.bertToken, sync2.SyncResponse{ + NextBatch: tc.bert + "_empty_sync", + }) + } + + t.Log("Bert makes new sliding sync connection.") + bertRes = v3.mustDoV3Request(t, tc.bertToken, ssRequest) + + // Work out what Bert should see. 
+ respMatchers := []m.RespMatcher{} + + switch tc.afterMembership { + case "invite": + t.Log("Bert should see his invite.") + respMatchers = append(respMatchers, + m.MatchList("a", m.MatchV3Count(1)), + m.MatchRoomSubscription(tc.publicRoomID, + m.MatchRoomHasInviteState(), + m.MatchInviteCount(1), + m.MatchJoinCount(1), + )) + case "join": + t.Log("Bert should see himself joined to the room, and Alice's messages.") + respMatchers = append(respMatchers, + m.MatchList("a", m.MatchV3Count(1)), + m.MatchRoomSubscription(tc.publicRoomID, + m.MatchRoomLacksInviteState(), + m.MatchRoomRequiredState([]json.RawMessage{anaMembership, newMembership}), + m.MatchInviteCount(0), + m.MatchJoinCount(2), + m.MatchRoomTimelineMostRecent(len(publicTimeline), publicTimeline), + m.MatchRoomPrevBatch("anaPublicPrevBatch2"), + )) + case "leave": + fallthrough + case "ban": + respMatchers = append(respMatchers, m.MatchList("a", m.MatchV3Count(0))) + // Any prior connection has been closed by the server, so Bert won't see + // a transition here. + t.Logf("Bob shouldn't see his %s (membership was: %s)", tc.afterMembership, tc.beforeMembership) + respMatchers = append(respMatchers, m.MatchRoomSubscriptionsStrict(nil)) + default: + panic(fmt.Errorf("unknown afterMembership %s", tc.afterMembership)) + } + + m.MatchResponse(t, bertRes, respMatchers...) + + // 7: Ana invites Bert to a DM. He accepts. + // This is a sentinel which proves the proxy has processed the gappy poll + // properly in the situations where there's nothing for Bert to see in his + // second sync, e.g. ban -> leave (an unban). + t.Log("Ana invites Bert to a DM. He accepts.") + bertDMJoin := testutils.NewJoinEvent(t, tc.bert) + dmTimeline := append( + createRoomState(t, tc.ana, time.Now()), + testutils.NewStateEvent(t, "m.room.member", tc.bert, tc.ana, map[string]any{"membership": "invite"}), + bertDMJoin, + ) + v2.queueResponse(tc.anaToken, sync2.SyncResponse{ + NextBatch: "ana3", + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + tc.dmRoomID: { + Timeline: sync2.TimelineResponse{ + Events: dmTimeline, + PrevBatch: "anaDM", + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, tc.anaToken) + + t.Log("Bert sliding syncs") + bertRes = v3.mustDoV3RequestWithPos(t, tc.bertToken, bertRes.Pos, ssRequest) + + t.Log("Bert sees his join to the DM.") + m.MatchResponse(t, bertRes, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{ + tc.dmRoomID: {m.MatchRoomLacksInviteState(), m.MatchRoomTimelineMostRecent(1, []json.RawMessage{bertDMJoin})}, + })) + }) + } +} + +// This is a minimal version of the test above, which is helpful for debugging (because +// the above test is a monstrosity---apologies to the reader.) 
+func TestTimelineAfterRequestingStateAfterGappyPoll(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + v2 := runTestV2Server(t) + defer v2.close() + v3 := runTestServer(t, v2, pqString) + defer v3.close() + + alice := "alice" + aliceToken := "alicetoken" + bob := "bob" + roomID := "!unimportant" + + v2.addAccount(t, alice, aliceToken) + + t.Log("alice creates a public room.") + timeline1 := createRoomState(t, alice, time.Now()) + var aliceMembership json.RawMessage + for _, ev := range timeline1 { + parsed := gjson.ParseBytes(ev) + if parsed.Get("type").Str == "m.room.member" && parsed.Get("state_key").Str == alice { + aliceMembership = ev + break + } + } + if len(aliceMembership) == 0 { + t.Fatal("Initial timeline did not have a membership for Alice") + } + + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomID: { + Timeline: sync2.TimelineResponse{ + Events: timeline1, + PrevBatch: "alicePublicPrevBatch1", + }, + }, + }, + }, + NextBatch: "aliceSync1", + }) + + t.Log("alice sliding syncs, requesting all memberships in state.") + aliceReq := sync3.Request{ + RoomSubscriptions: map[string]sync3.RoomSubscription{ + roomID: { + TimelineLimit: 20, + RequiredState: [][2]string{{"m.room.member", "*"}}, + }, + }, + } + aliceRes := v3.mustDoV3Request(t, aliceToken, aliceReq) + + t.Log("She sees herself joined to her room, with an appropriate timeline.") + // Note: we only expect timeline1[1:] here, excluding the create event. See + // https://github.com/matrix-org/sliding-sync/issues/343 + m.MatchResponse(t, aliceRes, + m.LogResponse(t), + m.MatchRoomSubscription(roomID, + m.MatchRoomRequiredState([]json.RawMessage{aliceMembership}), + m.MatchRoomTimeline(timeline1[1:])), + ) + + t.Logf("Alice's poller gets a gappy sync response for the public room. bob's membership is now join, and alice has sent 10 messages.") + timeline2 := make([]json.RawMessage, 10) + for i := range timeline2 { + timeline2[i] = testutils.NewMessageEvent(t, alice, fmt.Sprintf("hello %d", i)) + } + + bobMembership := testutils.NewJoinEvent(t, bob) + + v2.queueResponse(aliceToken, sync2.SyncResponse{ + NextBatch: "alice2", + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomID: { + State: sync2.EventsResponse{ + Events: []json.RawMessage{bobMembership}, + }, + Timeline: sync2.TimelineResponse{ + Events: timeline2, + Limited: true, + PrevBatch: "alicePublicPrevBatch2", + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, aliceToken) + + t.Log("Alice does an incremental sliding sync.") + _, respBytes, statusCode := v3.doV3Request(t, context.Background(), aliceToken, aliceRes.Pos, sync3.Request{}) + + t.Log("Her long-polling session has been closed by the server.") + assertUnknownPos(t, respBytes, statusCode) + + t.Log("Alice syncs again from scratch.") + aliceRes = v3.mustDoV3Request(t, aliceToken, aliceReq) + + t.Log("She sees both her and Bob's membership, and the timeline from the gappy poll.") + // Note: we don't expect to see timeline1 here because we stop at the first gap we + // see in the timeline. 
+ m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, + m.MatchRoomRequiredState([]json.RawMessage{aliceMembership, bobMembership}), + m.MatchRoomTimeline(timeline2), + )) +} + +func assertUnknownPos(t *testing.T, respBytes []byte, statusCode int) { + if statusCode != http.StatusBadRequest { + t.Errorf("Got status %d, expected %d", statusCode, http.StatusBadRequest) + } + if errcode := gjson.GetBytes(respBytes, "errcode").Str; errcode != "M_UNKNOWN_POS" { + t.Errorf("Got errcode %s, expected %s", errcode, "M_UNKNOWN_POS") + } +} From c239cacc8309e331172011a5f5c9304271b77c88 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Nov 2023 15:41:09 +0000 Subject: [PATCH 03/12] Initialise: handle gappy polls and ditch prependStateEvents --- state/accumulator.go | 271 ++++++++++++++++++++++++-------------- state/accumulator_test.go | 62 ++++++--- sync2/handler2/handler.go | 12 +- sync2/poller.go | 30 +---- sync2/poller_test.go | 14 +- 5 files changed, 243 insertions(+), 146 deletions(-) diff --git a/state/accumulator.go b/state/accumulator.go index bda2640f..46844fc9 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -141,69 +141,62 @@ type InitialiseResult struct { // AddedEvents is true iff this call to Initialise added new state events to the DB. AddedEvents bool // SnapshotID is the ID of the snapshot which incorporates all added events. - // It has no meaning if AddedEvents is False. + // It has no meaning if AddedEvents is false. SnapshotID int64 - // PrependTimelineEvents is empty if the room was not initialised prior to this call. - // Otherwise, it is an order-preserving subset of the `state` argument to Initialise - // containing all events that were not persisted prior to the Initialise call. These - // should be prepended to the room timeline by the caller. - PrependTimelineEvents []json.RawMessage + // ReplacedExistingSnapshot is true when we created a new snapshot for the room and + // there a pre-existing room snapshot. It has no meaning if AddedEvents is false. + ReplacedExistingSnapshot bool } -// Initialise starts a new sync accumulator for the given room using the given state as a baseline. +// Initialise processes the state block of a V2 sync response for a particular room. If +// the state of the room has changed, we persist any new state events and create a new +// "snapshot" of its entire state. // -// This will only take effect if this is the first time the v3 server has seen this room, and it wasn't -// possible to get all events up to the create event (e.g Matrix HQ). -// This function: -// - Stores these events -// - Sets up the current snapshot based on the state list given. +// Summary of the logic: // -// If the v3 server has seen this room before, this function -// - queries the DB to determine which state events are known to th server, -// - returns (via InitialiseResult.PrependTimelineEvents) a slice of unknown state events, +// 0. Ensure the state block is not empty. // -// and otherwise does nothing. +// 1. Capture the current snapshot ID, possibly zero. If it is zero, ensure that the +// state block contains a `create event`. +// +// 2. Insert the events. If there are no newly inserted events, bail. If there are new +// events, then the state block has definitely changed. 
Note: we ignore cases where +// the state has only changed to a known subset of state events (i.e in the case of +// state resets, slow pollers) as it is impossible to then reconcile that state with +// any new events, as any "catchup" state will be ignored due to the events already +// existing. +// +// 3. Fetch the current state of the room, as a map from (type, state_key) to event. +// If there is no existing state snapshot, this map is the empty map. +// If the state hasn't altered, bail. +// +// 4. Create new snapshot. Update the map from (3) with the events in `state`. +// (There is similar logic for this in Accumulate.) +// Store the snapshot. Mark the room's current state as being this snapshot. +// +// 5. Any other processing of the new state events. +// +// 6. Return an "AddedEvents" bool (if true, emit an Initialise payload) and a +// "ReplacedSnapshot" bool (if true, emit a cache invalidation payload). + func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) { var res InitialiseResult + var startingSnapshotID int64 + + // 0. Ensure the state block is not empty. if len(state) == 0 { return res, nil } - err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error { + err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) (err error) { + // 1. Capture the current snapshot ID, checking for a create event if this is our first snapshot. + // Attempt to short-circuit. This has to be done inside a transaction to make sure // we don't race with multiple calls to Initialise with the same room ID. - snapshotID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) + startingSnapshotID, err = a.roomsTable.CurrentAfterSnapshotID(txn, roomID) if err != nil { - return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err) - } - if snapshotID > 0 { - // Poller A has received a gappy sync v2 response with a state block, and - // we have seen this room before. If we knew for certain that there is some - // other active poller B in this room then we could safely skip this logic. - - // Log at debug for now. If we find an unknown event, we'll return it so - // that the poller can log a warning. - logger.Debug().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called with incremental state but current snapshot already exists.") - eventIDs := make([]string, len(state)) - eventIDToRawEvent := make(map[string]json.RawMessage, len(state)) - for i := range state { - eventID := gjson.ParseBytes(state[i]).Get("event_id") - if !eventID.Exists() || eventID.Type != gjson.String { - return fmt.Errorf("Event %d lacks an event ID", i) - } - eventIDToRawEvent[eventID.Str] = state[i] - eventIDs[i] = eventID.Str - } - unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, eventIDs) - if err != nil { - return fmt.Errorf("error determing which event IDs are unknown: %s", err) - } - for unknownEventID := range unknownEventIDs { - res.PrependTimelineEvents = append(res.PrependTimelineEvents, eventIDToRawEvent[unknownEventID]) - } - return nil + return fmt.Errorf("error fetching snapshot id for room %s: %w", roomID, err) } - - // We don't have a snapshot for this room. Parse the events first. + // Start by parsing the events in the state block. 
events := make([]Event, len(state)) for i := range events { events[i] = Event{ @@ -214,77 +207,76 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia } events = filterAndEnsureFieldsSet(events) if len(events) == 0 { - return fmt.Errorf("failed to insert events, all events were filtered out: %w", err) + return fmt.Errorf("failed to parse state block, all events were filtered out: %w", err) } - // Before proceeding further, ensure that we have "proper" state and not just a - // single stray event by looking for the create event. - hasCreate := false - for _, e := range events { - if e.Type == "m.room.create" && e.StateKey == "" { - hasCreate = true - break + if startingSnapshotID == 0 { + // Ensure that we have "proper" state and not "stray" events from Synapse. + if err = ensureStateHasCreateEvent(events); err != nil { + return err } } - if !hasCreate { - const errMsg = "cannot create first snapshot without a create event" - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ - "room_id": roomID, - "len_state": len(events), - }) - sentry.CaptureMessage(errMsg) - }) - logger.Warn(). - Str("room_id", roomID). - Int("len_state", len(events)). - Msg(errMsg) - // the HS gave us bad data so there's no point retrying => return DataError - return internal.NewDataError(errMsg) - } - // Insert the events. - eventIDToNID, err := a.eventsTable.Insert(txn, events, false) + // 2. Insert the events and determine which ones are new. + newEventIDToNID, err := a.eventsTable.Insert(txn, events, false) if err != nil { return fmt.Errorf("failed to insert events: %w", err) } - if len(eventIDToNID) == 0 { - // we don't have a current snapshot for this room but yet no events are new, - // no idea how this should be handled. - const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug." - logger.Error().Str("room_id", roomID).Msg(errMsg) - sentry.CaptureException(fmt.Errorf(errMsg)) + if len(newEventIDToNID) == 0 { + if startingSnapshotID == 0 { + // we don't have a current snapshot for this room but yet no events are new, + // no idea how this should be handled. + const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug." + logger.Error().Str("room_id", roomID).Msg(errMsg) + sentry.CaptureException(fmt.Errorf(errMsg)) + } + // Note: we otherwise ignore cases where the state has only changed to a + // known subset of state events (i.e in the case of state resets, slow + // pollers) as it is impossible to then reconcile that state with + // any new events, as any "catchup" state will be ignored due to the events + // already existing. return nil } - - // pull out the event NIDs we just inserted - membershipEventIDs := make(map[string]struct{}, len(events)) + newEvents := make([]Event, 0, len(newEventIDToNID)) for _, event := range events { - if event.Type == "m.room.member" { - membershipEventIDs[event.ID] = struct{}{} + newNid, isNew := newEventIDToNID[event.ID] + if isNew { + event.NID = newNid + newEvents = append(newEvents, event) } } - memberNIDs := make([]int64, 0, len(eventIDToNID)) - otherNIDs := make([]int64, 0, len(eventIDToNID)) - for evID, nid := range eventIDToNID { - if _, exists := membershipEventIDs[evID]; exists { - memberNIDs = append(memberNIDs, int64(nid)) - } else { - otherNIDs = append(otherNIDs, int64(nid)) + + // 3. 
Fetch the current state of the room. + var currentState stateMap + if startingSnapshotID > 0 { + currentState, err = a.stateMapAtSnapshot(txn, startingSnapshotID) + if err != nil { + return fmt.Errorf("failed to load state map: %w", err) + } + } else { + currentState = stateMap{ + Memberships: make(map[string]int64, len(events)), + Other: make(map[[2]string]int64, len(events)), } } - // Make a current snapshot + // 4. Update the map from (3) with the new events to create a new snapshot. + for _, ev := range newEvents { + currentState.Ingest(ev) + } + memberNIDs, otherNIDs := currentState.NIDs() snapshot := &SnapshotRow{ RoomID: roomID, - MembershipEvents: pq.Int64Array(memberNIDs), - OtherEvents: pq.Int64Array(otherNIDs), + MembershipEvents: memberNIDs, + OtherEvents: otherNIDs, } err = a.snapshotTable.Insert(txn, snapshot) if err != nil { return fmt.Errorf("failed to insert snapshot: %w", err) } res.AddedEvents = true + + // 5. Any other processing of new state events. latestNID := int64(0) for _, nid := range otherNIDs { if nid > latestNID { @@ -313,8 +305,16 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia // will have an associated state snapshot ID on the event. // Set the snapshot ID as the current state + err = a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID) + if err != nil { + return err + } + + // 6. Tell the caller what happened, so they know what payloads to emit. res.SnapshotID = snapshot.SnapshotID - return a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID) + res.AddedEvents = true + res.ReplacedExistingSnapshot = startingSnapshotID > 0 + return nil }) return res, err } @@ -652,3 +652,82 @@ func (a *Accumulator) filterToNewTimelineEvents(txn *sqlx.Tx, dedupedEvents []Ev // A is seen event s[A,B,C] => s[0+1:] => [B,C] return dedupedEvents[seenIndex+1:], nil } + +func ensureStateHasCreateEvent(events []Event) error { + hasCreate := false + for _, e := range events { + if e.Type == "m.room.create" && e.StateKey == "" { + hasCreate = true + break + } + } + if !hasCreate { + const errMsg = "cannot create first snapshot without a create event" + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ + "room_id": events[0].RoomID, + "len_state": len(events), + }) + sentry.CaptureMessage(errMsg) + }) + logger.Warn(). + Str("room_id", events[0].RoomID). + Int("len_state", len(events)). 
+ Msg(errMsg) + // the HS gave us bad data so there's no point retrying => return DataError + return internal.NewDataError(errMsg) + } + return nil +} + +type stateMap struct { + // state_key (user id) -> NID + Memberships map[string]int64 + // type, state_key -> NID + Other map[[2]string]int64 +} + +func (s *stateMap) Ingest(e Event) (replacedNID int64) { + if e.Type == "m.room.member" { + replacedNID = s.Memberships[e.StateKey] + s.Memberships[e.StateKey] = e.NID + } else { + key := [2]string{e.Type, e.StateKey} + replacedNID = s.Other[key] + s.Other[key] = e.NID + } + return +} + +func (s *stateMap) NIDs() (membershipNIDs, otherNIDs []int64) { + membershipNIDs = make([]int64, 0, len(s.Memberships)) + otherNIDs = make([]int64, 0, len(s.Other)) + for _, nid := range s.Memberships { + membershipNIDs = append(membershipNIDs, nid) + } + for _, nid := range s.Other { + otherNIDs = append(otherNIDs, nid) + } + return +} + +func (a *Accumulator) stateMapAtSnapshot(txn *sqlx.Tx, snapID int64) (stateMap, error) { + snapshot, err := a.snapshotTable.Select(txn, snapID) + if err != nil { + return stateMap{}, err + } + // pull stripped events as this may be huge (think Matrix HQ) + events, err := a.eventsTable.SelectStrippedEventsByNIDs(txn, true, append(snapshot.MembershipEvents, snapshot.OtherEvents...)) + if err != nil { + return stateMap{}, err + } + + state := stateMap{ + Memberships: make(map[string]int64, len(snapshot.MembershipEvents)), + Other: make(map[[2]string]int64, len(snapshot.OtherEvents)), + } + for _, e := range events { + state.Ingest(e) + } + return state, nil +} diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 65dc2440..db358830 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -35,9 +35,8 @@ func TestAccumulatorInitialise(t *testing.T) { if err != nil { t.Fatalf("falied to Initialise accumulator: %s", err) } - if !res.AddedEvents { - t.Fatalf("didn't add events, wanted it to") - } + assertValue(t, "res.AddedEvents", res.AddedEvents, true) + assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, false) txn, err := accumulator.db.Beginx() if err != nil { @@ -46,21 +45,21 @@ func TestAccumulatorInitialise(t *testing.T) { defer txn.Rollback() // There should be one snapshot on the current state - snapID, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) + snapID1, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) if err != nil { t.Fatalf("failed to select current snapshot: %s", err) } - if snapID == 0 { + if snapID1 == 0 { t.Fatalf("Initialise did not store a current snapshot") } - if snapID != res.SnapshotID { - t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID) + if snapID1 != res.SnapshotID { + t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID1) } // this snapshot should have 1 member event and 2 other events in it - row, err := accumulator.snapshotTable.Select(txn, snapID) + row, err := accumulator.snapshotTable.Select(txn, snapID1) if err != nil { - t.Fatalf("failed to select snapshot %d: %s", snapID, err) + t.Fatalf("failed to select snapshot %d: %s", snapID1, err) } if len(row.MembershipEvents) != 1 { t.Fatalf("got %d membership events, want %d in current state snapshot", len(row.MembershipEvents), 1) @@ -87,7 +86,7 @@ func TestAccumulatorInitialise(t *testing.T) { } } - // Subsequent calls do nothing and are not an error + // Subsequent calls with the same set of the events do nothing and are not 
an error. res, err = accumulator.Initialise(roomID, roomEvents) if err != nil { t.Fatalf("falied to Initialise accumulator: %s", err) @@ -95,6 +94,37 @@ func TestAccumulatorInitialise(t *testing.T) { if res.AddedEvents { t.Fatalf("added events when it shouldn't have") } + + // Subsequent calls with a subset of events do nothing and are not an error + res, err = accumulator.Initialise(roomID, roomEvents[:2]) + if err != nil { + t.Fatalf("falied to Initialise accumulator: %s", err) + } + if res.AddedEvents { + t.Fatalf("added events when it shouldn't have") + } + + // Subsequent calls with at least one new event expand or replace existing state. + // C, D, E + roomEvents2 := append(roomEvents[2:3], + []byte(`{"event_id":"D", "type":"m.room.topic", "state_key":"", "content":{"topic":"Dr Rick Dagless MD"}}`), + []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join", "displayname": "Garth""}}`), + ) + res, err = accumulator.Initialise(roomID, roomEvents2) + assertNoError(t, err) + assertValue(t, "res.AddedEvents", res.AddedEvents, true) + assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, true) + + snapID2, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) + assertNoError(t, err) + if snapID2 == snapID1 || snapID2 == 0 { + t.Errorf("Expected snapID2 (%d) to be neither snapID1 (%d) nor 0", snapID2, snapID1) + } + + row, err = accumulator.snapshotTable.Select(txn, snapID2) + assertNoError(t, err) + assertValue(t, "len(row.MembershipEvents)", len(row.MembershipEvents), 1) + assertValue(t, "len(row.OtherEvents)", len(row.OtherEvents), 3) } // Test that an unknown room shouldn't initialise if given state without a create event. @@ -115,9 +145,9 @@ func TestAccumulatorInitialiseBadInputs(t *testing.T) { func TestAccumulatorAccumulate(t *testing.T) { roomID := "!TestAccumulatorAccumulate:localhost" roomEvents := []json.RawMessage{ - []byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`), - []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), - []byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), + []byte(`{"event_id":"G", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`), + []byte(`{"event_id":"H", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), + []byte(`{"event_id":"I", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), } db, close := connectToDB(t) defer close() @@ -130,11 +160,11 @@ func TestAccumulatorAccumulate(t *testing.T) { // accumulate new state makes a new snapshot and removes the old snapshot newEvents := []json.RawMessage{ // non-state event does nothing - []byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`), + []byte(`{"event_id":"J", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`), // join_rules should clobber the one from initialise - []byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), + []byte(`{"event_id":"K", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), // new state event should be added to the snapshot - []byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), + []byte(`{"event_id":"L", 
"type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), } var result AccumulateResult err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 66ccf4c2..f336f30a 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -370,20 +370,24 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID strin return nil } -func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { res, err := h.Store.Initialise(roomID, state) if err != nil { logger.Err(err).Int("state", len(state)).Str("room", roomID).Msg("V2: failed to initialise room") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) - return nil, err + return err } - if res.AddedEvents { + if res.ReplacedExistingSnapshot { + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{ + RoomID: roomID, + }) + } else if res.AddedEvents { h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{ RoomID: roomID, SnapshotNID: res.SnapshotID, }) } - return res.PrependTimelineEvents, nil + return nil } func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) { diff --git a/sync2/poller.go b/sync2/poller.go index cc456a6d..b3a80bb9 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -40,7 +40,7 @@ type V2DataReceiver interface { // Initialise the room, if it hasn't been already. This means the state section of the v2 response. // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. // Return an error to stop the since token advancing. - Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) // snapshot ID? + Initialise(ctx context.Context, roomID string, state []json.RawMessage) error // snapshot ID? // SetTyping indicates which users are typing. SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt @@ -326,11 +326,11 @@ func (h *PollerMap) Accumulate(ctx context.Context, userID, deviceID, roomID str wg.Wait() return } -func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (result []json.RawMessage, err error) { +func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (err error) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - result, err = h.callbacks.Initialise(ctx, roomID, state) + err = h.callbacks.Initialise(ctx, roomID, state) wg.Done() } wg.Wait() @@ -789,30 +789,14 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) erro for roomID, roomData := range res.Rooms.Join { if len(roomData.State.Events) > 0 { stateCalls++ - prependStateEvents, err := p.receiver.Initialise(ctx, roomID, roomData.State.Events) + if roomData.Timeline.Limited { + p.trackGappyStateSize(len(roomData.State.Events)) + } + err := p.receiver.Initialise(ctx, roomID, roomData.State.Events) if err != nil { lastErrs = append(lastErrs, fmt.Errorf("Initialise[%s]: %w", roomID, err)) continue } - if len(prependStateEvents) > 0 { - // The poller has just learned of these state events due to an - // incremental poller sync; we must have missed the opportunity to see - // these down /sync in a timeline. 
As a workaround, inject these into - // the timeline now so that future events are received under the - // correct room state. - const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll" - logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg) - hub := internal.GetSentryHubFromContextOrDefault(ctx) - hub.WithScope(func(scope *sentry.Scope) { - scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ - "room_id": roomID, - "num_prepend_state_events": len(prependStateEvents), - }) - hub.CaptureMessage(warnMsg) - }) - p.trackGappyStateSize(len(prependStateEvents)) - roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...) - } } // process typing/receipts before events so we seed the caches correctly for when we return the room for _, ephEvent := range roomData.Ephemeral.Events { diff --git a/sync2/poller_test.go b/sync2/poller_test.go index ccaa6b32..1ce7a745 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -830,8 +830,8 @@ func TestPollerResendsOnCallbackError(t *testing.T) { // generate a receiver which errors for the right callback generateReceiver: func() V2DataReceiver { return &overrideDataReceiver{ - initialise: func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { - return nil, fmt.Errorf("initialise error") + initialise: func(ctx context.Context, roomID string, state []json.RawMessage) error { + return fmt.Errorf("initialise error") }, } }, @@ -1273,7 +1273,7 @@ func (a *mockDataReceiver) Accumulate(ctx context.Context, userID, deviceID, roo a.timelines[roomID] = append(a.timelines[roomID], timeline.Events...) return nil } -func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { a.states[roomID] = state if a.incomingProcess != nil { a.incomingProcess <- struct{}{} @@ -1283,7 +1283,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state } // The return value is a list of unknown state events to be prepended to the room // timeline. Untested here---return nil for now. 
- return nil, nil + return nil } func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) { s.mu.Lock() @@ -1296,7 +1296,7 @@ func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, device type overrideDataReceiver struct { accumulate func(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) error - initialise func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) + initialise func(ctx context.Context, roomID string, state []json.RawMessage) error setTyping func(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) updateDeviceSince func(ctx context.Context, userID, deviceID, since string) addToDeviceMessages func(ctx context.Context, userID, deviceID string, msgs []json.RawMessage) error @@ -1316,9 +1316,9 @@ func (s *overrideDataReceiver) Accumulate(ctx context.Context, userID, deviceID, } return s.accumulate(ctx, userID, deviceID, roomID, timeline.PrevBatch, timeline.Events) } -func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { if s.initialise == nil { - return nil, nil + return nil } return s.initialise(ctx, roomID, state) } From eb1ada2f95117332ba702ae7c089283377442c10 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Nov 2023 15:43:42 +0000 Subject: [PATCH 04/12] Remove OnInvalidateRoom from Reciever interface --- sync3/caches/user.go | 7 ------- sync3/dispatcher.go | 1 - 2 files changed, 8 deletions(-) diff --git a/sync3/caches/user.go b/sync3/caches/user.go index ca2bc429..0db8657e 100644 --- a/sync3/caches/user.go +++ b/sync3/caches/user.go @@ -771,10 +771,3 @@ func (u *UserCache) ShouldIgnore(userID string) bool { _, ignored := u.ignoredUsers[userID] return ignored } - -func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) { - // Nothing for now. In UserRoomData the fields dependant on room state are - // IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces. - // Not clear to me if we need to reload these or if we will inherit any changes from - // the global cache. 
-} diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 9fd21de1..52119c6d 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -24,7 +24,6 @@ type Receiver interface { OnNewEvent(ctx context.Context, event *caches.EventData) OnReceipt(ctx context.Context, receipt internal.Receipt) OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage) - OnInvalidateRoom(ctx context.Context, roomID string) // OnRegistered is called after a successful call to Dispatcher.Register OnRegistered(ctx context.Context) error } From c6fb96ac707268e89334e7c8b17f1527df13b86d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 3 Nov 2023 15:44:59 +0000 Subject: [PATCH 05/12] Nuke connections after a room is invalidated --- sync3/connmap.go | 13 +++++++++ sync3/dispatcher.go | 21 +++------------ sync3/handler/handler.go | 57 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/sync3/connmap.go b/sync3/connmap.go index 69d02c7d..575ec0b1 100644 --- a/sync3/connmap.go +++ b/sync3/connmap.go @@ -181,6 +181,19 @@ func (m *ConnMap) connIDsForDevice(userID, deviceID string) []ConnID { return connIDs } +// CloseConnsForUser closes all conns for a given user. Returns the number of conns closed. +func (m *ConnMap) CloseConnsForUser(userID string) int { + m.mu.Lock() + defer m.mu.Unlock() + conns := m.userIDToConn[userID] + logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()") + + for _, cid := range conns { + m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn + } + return len(conns) +} + func (m *ConnMap) closeConnExpires(connID string, value interface{}) { m.mu.Lock() defer m.mu.Unlock() diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 52119c6d..fea5cbee 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -275,22 +275,7 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData, } } -func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) { - // First dispatch to the global cache. - receiver, ok := d.userToReceiver[DispatcherAllUsers] - if !ok { - logger.Error().Msgf("No receiver for global cache") - } - receiver.OnInvalidateRoom(ctx, roomID) - - // Then dispatch to any users who are joined to that room. - joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil) - d.userToReceiverMu.RLock() - defer d.userToReceiverMu.RUnlock() - for _, userID := range joinedUsers { - receiver = d.userToReceiver[userID] - if receiver != nil { - receiver.OnInvalidateRoom(ctx, roomID) - } - } +func (d *Dispatcher) OnInvalidateRoom(roomID string, joins, invites []string) { + // Reset the joined room tracker. + d.jrt.ReloadMembershipsForRoom(roomID, joins, invites) } diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index 03a612d9..49a5dff0 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -63,6 +63,11 @@ type SyncLiveHandler struct { setupHistVec *prometheus.HistogramVec histVec *prometheus.HistogramVec slowReqs prometheus.Counter + // destroyedConns is the number of connections that have been destoryed after + // a room invalidation payload. + // TODO: could make this a CounterVec labelled by reason, to track expiry due + // to update buffer filling, expiry due to inactivity, etc. 
+	destroyedConns prometheus.Counter
 }
 
 func NewSync3Handler(
@@ -139,6 +144,9 @@ func (h *SyncLiveHandler) Teardown() {
 	if h.slowReqs != nil {
 		prometheus.Unregister(h.slowReqs)
 	}
+	if h.destroyedConns != nil {
+		prometheus.Unregister(h.destroyedConns)
+	}
 }
 
 func (h *SyncLiveHandler) addPrometheusMetrics() {
@@ -162,9 +170,17 @@ func (h *SyncLiveHandler) addPrometheusMetrics() {
 		Name:      "slow_requests",
 		Help:      "Counter of slow (>=50s) requests, initial or otherwise.",
 	})
+	h.destroyedConns = prometheus.NewCounter(prometheus.CounterOpts{
+		Namespace: "sliding_sync",
+		Subsystem: "api",
+		Name:      "destroyed_conns",
+		Help:      "Counter of conns that were destroyed.",
+	})
+
 	prometheus.MustRegister(h.setupHistVec)
 	prometheus.MustRegister(h.histVec)
 	prometheus.MustRegister(h.slowReqs)
+	prometheus.MustRegister(h.destroyedConns)
 }
 
 func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -818,7 +834,46 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
 	ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom")
 	defer task.End()
 
-	h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID)
+	// 1. Reload the global cache.
+	h.GlobalCache.OnInvalidateRoom(ctx, p.RoomID)
+
+	// Work out who is affected.
+	joins, invites, leaves, err := h.Storage.FetchMemberships(p.RoomID)
+	involvedUsers := make([]string, 0, len(joins)+len(invites)+len(leaves))
+	involvedUsers = append(involvedUsers, joins...)
+	involvedUsers = append(involvedUsers, invites...)
+	involvedUsers = append(involvedUsers, leaves...)
+
+	// 2. Reload the joined-room tracker.
+	if err != nil {
+		hub := internal.GetSentryHubFromContextOrDefault(ctx)
+		hub.WithScope(func(scope *sentry.Scope) {
+			scope.SetContext(internal.SentryCtxKey, map[string]any{
+				"room_id": p.RoomID,
+			})
+			hub.CaptureException(err)
+		})
+		logger.Err(err).
+			Str("room_id", p.RoomID).
+			Msg("Failed to fetch members after cache invalidation")
+	}
+
+	h.Dispatcher.OnInvalidateRoom(p.RoomID, joins, invites)
+
+	// 3. Destroy involved users' caches.
+	for _, userID := range involvedUsers {
+		h.Dispatcher.Unregister(userID)
+		h.userCaches.Delete(userID)
+	}
+
+	// 4. Destroy involved users' connections.
+	var destroyed int
+	for _, userID := range involvedUsers {
+		destroyed += h.ConnMap.CloseConnsForUser(userID)
+	}
+	if h.destroyedConns != nil {
+		h.destroyedConns.Add(float64(destroyed))
+	}
 }
 
 func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) {

From 2044af96de90a8e88c99cbb88b3ded35ea9b1f39 Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Fri, 3 Nov 2023 15:45:28 +0000
Subject: [PATCH 06/12] Comment improvements

---
 state/storage.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/state/storage.go b/state/storage.go
index c8a61e73..563844cc 100644
--- a/state/storage.go
+++ b/state/storage.go
@@ -365,6 +365,8 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
 
 	// For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID.
 	// These shouldn't be changing during a room's lifetime in normal operation.
+
+	// We haven't updated LatestEventsByType because that's not part of the timeline.
 	return nil
 }

From 2052632d0c99f55051ba99e2022e21227e5598a2 Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Fri, 3 Nov 2023 17:56:25 +0000
Subject: [PATCH 07/12] E2E test changes from #329.

I don't fully remember the details, but may as well pluck this from the
previous PR.
---
 tests-e2e/gappy_state_test.go | 50 +++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 8 deletions(-)

diff --git a/tests-e2e/gappy_state_test.go b/tests-e2e/gappy_state_test.go
index ce6a3fd0..209fc9b3 100644
--- a/tests-e2e/gappy_state_test.go
+++ b/tests-e2e/gappy_state_test.go
@@ -1,6 +1,7 @@
 package syncv3_test
 
 import (
+	"encoding/json"
 	"fmt"
 	"testing"
 
@@ -65,10 +66,21 @@ func TestGappyState(t *testing.T) {
 		Content:  nameContent,
 	})
 
-	t.Log("Alice sends lots of message events (more than the poller will request in a timeline.")
-	var latestMessageID string
-	for i := 0; i < 51; i++ {
-		latestMessageID = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
+	t.Log("Alice sends lots of other state events.")
+	const numOtherState = 40
+	for i := 0; i < numOtherState; i++ {
+		alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
+			Type:     "com.example.dummy",
+			StateKey: ptr(fmt.Sprintf("%d", i)),
+			Content:  map[string]any{},
+		})
+	}
+
+	t.Log("Alice sends a batch of message events.")
+	const numMessages = 20
+	var lastMsgID string
+	for i := 0; i < numMessages; i++ {
+		lastMsgID = alice.Unsafe_SendEventUnsynced(t, roomID, b.Event{
 			Type: "m.room.message",
 			Content: map[string]interface{}{
 				"msgtype": "m.text",
@@ -77,28 +89,50 @@ func TestGappyState(t *testing.T) {
 		})
 	}
 
-	t.Log("Alice requests an initial sliding sync on device 2.")
+	t.Logf("The proxy is now %d events behind the HS, which should trigger a limited sync", 1+numOtherState+numMessages)
+
+	t.Log("Alice requests an initial sliding sync on device 2, with timeline limit big enough to see her first message at the start of the test.")
 	syncResp = alice.SlidingSync(t,
 		sync3.Request{
 			Lists: map[string]sync3.RequestList{
 				"a": {
 					Ranges: [][2]int64{{0, 20}},
 					RoomSubscription: sync3.RoomSubscription{
-						TimelineLimit: 10,
+						TimelineLimit: 100,
 					},
 				},
 			},
 		},
 	)
 
-	t.Log("She should see her latest message with the room name updated")
+	// We're testing here that the state events from the gappy poll are NOT injected
+	// into the timeline. The poll is only going to use timeline limit 1 because it's
+	// the first poll on a new device. See integration test for a "proper" gappy poll.
+	t.Log("She should see the updated room name, her most recent message, but NOT the state events in the gap nor messages from before the gap.")
 	m.MatchResponse(
 		t,
 		syncResp,
 		m.MatchRoomSubscription(
 			roomID,
 			m.MatchRoomName("potato"),
-			MatchRoomTimelineMostRecent(1, []Event{{ID: latestMessageID}}),
+			MatchRoomTimelineMostRecent(1, []Event{{ID: lastMsgID}}),
+			func(r sync3.Room) error {
+				for _, rawEv := range r.Timeline {
+					var ev Event
+					err := json.Unmarshal(rawEv, &ev)
+					if err != nil {
+						t.Fatal(err)
+					}
+					// Shouldn't see the state events, only messages
+					if ev.Type != "m.room.message" {
+						return fmt.Errorf("timeline contained event %s of type %s (expected m.room.message)", ev.ID, ev.Type)
+					}
+					if ev.ID == firstMessageID {
+						return fmt.Errorf("timeline contained first message from before the gap")
+					}
+				}
+				return nil
+			},
 		),
 	)
 }

From 9ad93b861ed661dffe15a99afbd89577652d774c Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Tue, 7 Nov 2023 13:28:16 +0000
Subject: [PATCH 08/12] Initialise: let `Other` state self-size

---
 state/accumulator.go | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/state/accumulator.go b/state/accumulator.go
index 46844fc9..678ef4e8 100644
--- a/state/accumulator.go
+++ b/state/accumulator.go
@@ -255,8 +255,9 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
 		}
 	} else {
 		currentState = stateMap{
+			// Typically expect Other to be small, but Memberships may be large (think: Matrix HQ.)
 			Memberships: make(map[string]int64, len(events)),
-			Other:       make(map[[2]string]int64, len(events)),
+			Other:       make(map[[2]string]int64),
 		}
 	}

From 041965ffd0f85b67c2b0a7d9ace2a5f5c822d6dc Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Tue, 7 Nov 2023 13:56:12 +0000
Subject: [PATCH 09/12] Invalidate: bail if we fail to fetch members

---
 sync3/handler/handler.go | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index 49a5dff0..7dc58584 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -844,7 +844,6 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
 	involvedUsers = append(involvedUsers, invites...)
 	involvedUsers = append(involvedUsers, leaves...)
 
-	// 2. Reload the joined-room tracker.
 	if err != nil {
 		hub := internal.GetSentryHubFromContextOrDefault(ctx)
 		hub.WithScope(func(scope *sentry.Scope) {
@@ -856,8 +855,10 @@
 		logger.Err(err).
 			Str("room_id", p.RoomID).
 			Msg("Failed to fetch members after cache invalidation")
+		return
 	}
 
+	// 2. Reload the joined-room tracker.
 	h.Dispatcher.OnInvalidateRoom(p.RoomID, joins, invites)
 
 	// 3. Destroy involved users' caches.

From 78b1e5970c179802e54dd86f0fbdc54f63e37df0 Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Tue, 7 Nov 2023 17:13:05 +0000
Subject: [PATCH 10/12] Batch unregister users

acquire mutex once, rather than N times

---
 sync3/dispatcher.go      | 16 ++++++++++++++--
 sync3/handler/handler.go |  5 +++--
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go
index fea5cbee..9385d1bb 100644
--- a/sync3/dispatcher.go
+++ b/sync3/dispatcher.go
@@ -55,10 +55,22 @@ func (d *Dispatcher) Startup(roomToJoinedUsers map[string][]string) error {
 	return nil
 }
 
-func (d *Dispatcher) Unregister(userID string) {
+// UnregisterBulk accepts a slice of user IDs to unregister. The given users need not
+// already be registered (in which case unregistering them is a no-op). Returns the
+// list of users that were unregistered.
+func (d *Dispatcher) UnregisterBulk(userIDs []string) []string {
 	d.userToReceiverMu.Lock()
 	defer d.userToReceiverMu.Unlock()
-	delete(d.userToReceiver, userID)
+
+	unregistered := make([]string)
+	for _, userID := range userIDs {
+		_, exists := d.userToReceiver[userID]
+		if exists {
+			delete(d.userToReceiver, userID)
+			unregistered = append(unregistered, userID)
+		}
+	}
+	return unregistered
 }
 
 func (d *Dispatcher) Register(ctx context.Context, userID string, r Receiver) error {
diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index 7dc58584..00824dc9 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -862,8 +862,9 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
 	h.Dispatcher.OnInvalidateRoom(p.RoomID, joins, invites)
 
 	// 3. Destroy involved users' caches.
-	for _, userID := range involvedUsers {
-		h.Dispatcher.Unregister(userID)
+	// We filter to only those users which had a userCache registered to receive updates.
+	unregistered := h.Dispatcher.UnregisterBulk(involvedUsers)
+	for _, userID := range unregistered {
 		h.userCaches.Delete(userID)
 	}

From 4011e3812ae2a8940da22aa4e8e6959a6c2e931d Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Tue, 7 Nov 2023 18:33:08 +0000
Subject: [PATCH 11/12] Batch destroy conns

---
 sync3/connmap.go         | 18 +++++++++++-------
 sync3/handler/handler.go |  6 ++----
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/sync3/connmap.go b/sync3/connmap.go
index 575ec0b1..3803d674 100644
--- a/sync3/connmap.go
+++ b/sync3/connmap.go
@@ -181,17 +181,21 @@ func (m *ConnMap) connIDsForDevice(userID, deviceID string) []ConnID {
 	return connIDs
 }
 
-// CloseConnsForUser closes all conns for a given user. Returns the number of conns closed.
-func (m *ConnMap) CloseConnsForUser(userID string) int {
+// CloseConnsForUsers closes all conns for a given slice of users. Returns the number of
+// conns closed.
+func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	conns := m.userIDToConn[userID]
-	logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")
+	for _, userID := range userIDs {
+		conns := m.userIDToConn[userID]
+		logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")
 
-	for _, cid := range conns {
-		m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn
+		for _, conn := range conns {
+			m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
+		}
+		closed += len(conns)
 	}
-	return len(conns)
+	return closed
 }
 
 func (m *ConnMap) closeConnExpires(connID string, value interface{}) {
diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index 00824dc9..906c973b 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -869,10 +869,8 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) {
 	}
 
 	// 4. Destroy involved users' connections.
- var destroyed int - for _, userID := range involvedUsers { - destroyed += h.ConnMap.CloseConnsForUser(userID) - } + // Since creating a conn creates a user cache, it is safe to loop over + destroyed := h.ConnMap.CloseConnsForUsers(unregistered) if h.destroyedConns != nil { h.destroyedConns.Add(float64(destroyed)) } From 9aa8f55507da26e5ad93fd07bf874595a62d8cc2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 7 Nov 2023 18:43:20 +0000 Subject: [PATCH 12/12] Fix build --- sync3/dispatcher.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 9385d1bb..a7bb24b1 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -55,6 +55,12 @@ func (d *Dispatcher) Startup(roomToJoinedUsers map[string][]string) error { return nil } +func (d *Dispatcher) Unregister(userID string) { + d.userToReceiverMu.Lock() + defer d.userToReceiverMu.Unlock() + delete(d.userToReceiver, userID) +} + // UnregisterBulk accepts a slice of user IDs to unregister. The given users need not // already be registered (in which case unregistering them is a no-op). Returns the // list of users that were unregistered. @@ -62,7 +68,7 @@ func (d *Dispatcher) UnregisterBulk(userIDs []string) []string { d.userToReceiverMu.Lock() defer d.userToReceiverMu.Unlock() - unregistered := make([]string) + unregistered := make([]string, 0) for _, userID := range userIDs { _, exists := d.userToReceiver[userID] if exists {