Skip to content

Commit

Permalink
Faster joins tests: make request handlers more specific
Browse files Browse the repository at this point in the history
I'm going to add a test which will involve *multiple* /state and /state_ids
requests, so we need to make the registered handlers more selective.
  • Loading branch information
richvdh committed Jul 21, 2022
1 parent 4418d2c commit 687ad36
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions tests/federation_room_join_partial_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,12 +589,23 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU
},
}))

// register a handler for /state_ids requests, which finishes fedStateIdsRequestReceivedWaiter, then
// register a handler for /state_ids requests for the most recent event,
// which finishes fedStateIdsRequestReceivedWaiter, then
// waits for fedStateIdsSendResponseWaiter and sends a reply
handleStateIdsRequests(t, result.Server, result.ServerRoom, result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter)
lastEvent := result.ServerRoom.Timeline[len(result.ServerRoom.Timeline)-1]
currentState := result.ServerRoom.AllCurrentState()
handleStateIdsRequests(
t, result.Server, result.ServerRoom,
lastEvent.EventID(), currentState,
result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter,
)

// a handler for /state requests, which sends a sensible response
handleStateRequests(t, result.Server, result.ServerRoom, nil, nil)
handleStateRequests(
t, result.Server, result.ServerRoom,
lastEvent.EventID(), currentState,
nil, nil,
)

// have joiningUser join the room by room ID.
joiningUser.JoinRoom(t, result.ServerRoom.RoomID, []string{result.Server.ServerName()})
Expand Down Expand Up @@ -630,16 +641,20 @@ func (psj *partialStateJoinResult) FinishStateRequest() {
psj.fedStateIdsSendResponseWaiter.Finish()
}

// handleStateIdsRequests registers a handler for /state_ids requests for serverRoom.
// handleStateIdsRequests registers a handler for /state_ids requests for 'eventID'
//
// the returned state is as passed in 'roomState'
//
// if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives.
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateIdsRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
eventID string, roomState []*gomatrixserverlib.Event,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
) {
srv.Mux().Handle(
srv.Mux().NewRoute().Methods("GET").Path(
fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", serverRoom.RoomID),
).Queries("event_id", eventID).Handler(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
queryParams := req.URL.Query()
t.Logf("Incoming state_ids request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID)
Expand All @@ -652,8 +667,8 @@ func handleStateIdsRequests(
t.Logf("Replying to /state_ids request")

res := gomatrixserverlib.RespStateIDs{
AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChain()),
StateEventIDs: eventIDsFromEvents(serverRoom.AllCurrentState()),
AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChainForEvents(roomState)),
StateEventIDs: eventIDsFromEvents(roomState),
}
w.WriteHeader(200)
jsonb, _ := json.Marshal(res)
Expand All @@ -662,19 +677,24 @@ func handleStateIdsRequests(
t.Errorf("Error writing to request: %v", err)
}
}),
).Methods("GET")
)
t.Logf("Registered state_ids handler for event %s", eventID)
}

// makeStateHandler returns a handler for /state requests for serverRoom.
// makeStateHandler returns a handler for /state requests for 'eventID'
//
// the returned state is as passed in 'roomState'
//
// if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives.
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
eventID string, roomState []*gomatrixserverlib.Event,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
) {
srv.Mux().Handle(
srv.Mux().NewRoute().Methods("GET").Path(
fmt.Sprintf("/_matrix/federation/v1/state/%s", serverRoom.RoomID),
).Queries("event_id", eventID).Handler(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
queryParams := req.URL.Query()
t.Logf("Incoming state request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID)
Expand All @@ -685,8 +705,8 @@ func handleStateRequests(
sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state request")
}
res := gomatrixserverlib.RespState{
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChain()),
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AllCurrentState()),
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChainForEvents(roomState)),
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(roomState),
}
w.WriteHeader(200)
jsonb, _ := json.Marshal(res)
Expand All @@ -695,7 +715,7 @@ func handleStateRequests(
t.Errorf("Error writing to request: %v", err)
}
}),
).Methods("GET")
)
}

func eventIDsFromEvents(he []*gomatrixserverlib.Event) []string {
Expand Down

0 comments on commit 687ad36

Please sign in to comment.