diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index a532e9ca..3fa1092a 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -589,12 +589,23 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU }, })) - // register a handler for /state_ids requests, which finishes fedStateIdsRequestReceivedWaiter, then + // register a handler for /state_ids requests for the most recent event, + // which finishes fedStateIdsRequestReceivedWaiter, then // waits for fedStateIdsSendResponseWaiter and sends a reply - handleStateIdsRequests(t, result.Server, result.ServerRoom, result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter) + lastEvent := result.ServerRoom.Timeline[len(result.ServerRoom.Timeline)-1] + currentState := result.ServerRoom.AllCurrentState() + handleStateIdsRequests( + t, result.Server, result.ServerRoom, + lastEvent.EventID(), currentState, + result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter, + ) // a handler for /state requests, which sends a sensible response - handleStateRequests(t, result.Server, result.ServerRoom, nil, nil) + handleStateRequests( + t, result.Server, result.ServerRoom, + lastEvent.EventID(), currentState, + nil, nil, + ) // have joiningUser join the room by room ID. joiningUser.JoinRoom(t, result.ServerRoom.RoomID, []string{result.Server.ServerName()}) @@ -630,16 +641,20 @@ func (psj *partialStateJoinResult) FinishStateRequest() { psj.fedStateIdsSendResponseWaiter.Finish() } -// handleStateIdsRequests registers a handler for /state_ids requests for serverRoom. +// handleStateIdsRequests registers a handler for /state_ids requests for 'eventID' +// +// the returned state is as passed in 'roomState' // // if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives. // if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response. func handleStateIdsRequests( t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom, + eventID string, roomState []*gomatrixserverlib.Event, requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, ) { - srv.Mux().Handle( + srv.Mux().NewRoute().Methods("GET").Path( fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", serverRoom.RoomID), + ).Queries("event_id", eventID).Handler( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { queryParams := req.URL.Query() t.Logf("Incoming state_ids request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID) @@ -652,8 +667,8 @@ func handleStateIdsRequests( t.Logf("Replying to /state_ids request") res := gomatrixserverlib.RespStateIDs{ - AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChain()), - StateEventIDs: eventIDsFromEvents(serverRoom.AllCurrentState()), + AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChainForEvents(roomState)), + StateEventIDs: eventIDsFromEvents(roomState), } w.WriteHeader(200) jsonb, _ := json.Marshal(res) @@ -662,19 +677,24 @@ func handleStateIdsRequests( t.Errorf("Error writing to request: %v", err) } }), - ).Methods("GET") + ) + t.Logf("Registered state_ids handler for event %s", eventID) } -// makeStateHandler returns a handler for /state requests for serverRoom. +// makeStateHandler returns a handler for /state requests for 'eventID' +// +// the returned state is as passed in 'roomState' // // if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives. // if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response. func handleStateRequests( t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom, + eventID string, roomState []*gomatrixserverlib.Event, requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, ) { - srv.Mux().Handle( + srv.Mux().NewRoute().Methods("GET").Path( fmt.Sprintf("/_matrix/federation/v1/state/%s", serverRoom.RoomID), + ).Queries("event_id", eventID).Handler( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { queryParams := req.URL.Query() t.Logf("Incoming state request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID) @@ -685,8 +705,8 @@ func handleStateRequests( sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state request") } res := gomatrixserverlib.RespState{ - AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChain()), - StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AllCurrentState()), + AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChainForEvents(roomState)), + StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(roomState), } w.WriteHeader(200) jsonb, _ := json.Marshal(res) @@ -695,7 +715,7 @@ func handleStateRequests( t.Errorf("Error writing to request: %v", err) } }), - ).Methods("GET") + ) } func eventIDsFromEvents(he []*gomatrixserverlib.Event) []string {