From 93b1a2d0da1a8201974348628c4662265685871e Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 8 May 2019 18:33:45 -0700 Subject: [PATCH 01/23] Prevent multiple device list updates from breaking a batch send --- synapse/storage/devices.py | 42 +++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index fd869b934c7b..82b1b4f424bd 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -97,20 +97,48 @@ def get_devices_by_remote(self, destination, from_stream_id): def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id ): + # We retrieve n+1 devices from the list of outbound pokes were n is our + # maximum. We then check if the very last device has the same stream_id as the + # second-to-last device. If so, then we ignore all devices with that stream_id + # and only send the devices with a lower stream_id. + # + # If when culling the list we end up with no devices afterwards, we consider the + # device update to be too large, and simply skip the stream_id - the rationale + # being that such a large device list update is likely an error. + maximum_devices = 100 sql = """ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? GROUP BY user_id, device_id - LIMIT 20 - """ + ORDER BY stream_id + LIMIT %d + """ % (maximum_devices + 1) txn.execute(sql, (destination, from_stream_id, now_stream_id, False)) - # maps (user_id, device_id) -> stream_id - query_map = {(r[0], r[1]): r[2] for r in txn} - if not query_map: - return (now_stream_id, []) + updates = [r for r in txn] - if len(query_map) >= 20: + # TODO: Does this actually do what we want it to do? 
+ + # Check if the last and second-to-last row's stream_id's are the same + offending_stream_id = None + if ( + len(updates) > maximum_devices and + updates[-1][2] == updates[-2][2] + ): + offending_stream_id = updates[-1][2] + + # maps (user_id, device_id) -> stream_id + # as long as their stream_id does not match that of the last row + query_map = { + (r[0], r[1]): r[2] for r in updates + if r[2] is not offending_stream_id + } + + # If we ended up not being left over with any device updates to send + # out, then skip this stream_id + if len(query_map) == 0: + return (now_stream_id + 1, []) + elif len(query_map) >= maximum_devices: now_stream_id = max(stream_id for stream_id in itervalues(query_map)) devices = self._get_e2e_device_keys_txn( From 0ee2a8bf90e1fa275cfcee8ea8cb633274f9deb5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 8 May 2019 18:37:43 -0700 Subject: [PATCH 02/23] Add changelog --- changelog.d/5156.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/5156.bugfix diff --git a/changelog.d/5156.bugfix b/changelog.d/5156.bugfix new file mode 100644 index 000000000000..e8aa7d8241c1 --- /dev/null +++ b/changelog.d/5156.bugfix @@ -0,0 +1 @@ +Prevent federation device list updates breaking when processing multiple updates at once. 
\ No newline at end of file From 56cf3fb064063ce0ab0d63c0448b0f11f0a67719 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 9 May 2019 12:36:26 -0700 Subject: [PATCH 03/23] GROUP BY in python --- synapse/storage/devices.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 82b1b4f424bd..a5c1ba078843 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -105,19 +105,40 @@ def _get_devices_by_remote_txn( # If when culling the list we end up with no devices afterwards, we consider the # device update to be too large, and simply skip the stream_id - the rationale # being that such a large device list update is likely an error. + # + # Note: The code below assumes this value is at least 1 maximum_devices = 100 sql = """ - SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? 
- GROUP BY user_id, device_id ORDER BY stream_id LIMIT %d """ % (maximum_devices + 1) txn.execute(sql, (destination, from_stream_id, now_stream_id, False)) - updates = [r for r in txn] + duplicate_updates = [r for r in txn] + + # Return if there are no updates to send out + if len(duplicate_updates) == 0: + return (now_stream_id, []) + + # Perform the equivalent of a GROUP BY + # Iterate through the updates list and copy any non-duplicate + # (user_id, device_id) entries + updates = [duplicate_updates[0]] + for i in range(1, len(duplicate_updates)): + update = duplicate_updates[i] + prev_update = duplicate_updates[i-1] + + if (update[0], update[1]) == (prev_update[0], prev_update[1]): + # This is a duplicate, don't copy it over + # However if its stream_id is higher, copy that to the new list + if update[3] > prev_update[3]: + updates[-1][3] = update[3] + continue - # TODO: Does this actually do what we want it to do? + # Not a duplicate, copy over + updates.append(update) # Check if the last and second-to-last row's stream_id's are the same offending_stream_id = None From a843676f155e7765d2ea5138d22bdc7c17e46003 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 9 May 2019 12:38:35 -0700 Subject: [PATCH 04/23] lint --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index a5c1ba078843..7629ebc6718a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -126,7 +126,7 @@ def _get_devices_by_remote_txn( # Iterate through the updates list and copy any non-duplicate # (user_id, device_id) entries updates = [duplicate_updates[0]] - for i in range(1, len(duplicate_updates)): + for i in range(1, len(duplicate_updates)): update = duplicate_updates[i] prev_update = duplicate_updates[i-1] From 80b6e1ae72122b10e6bae59409095be80f147ffe Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 9 May 2019 12:40:26 -0700 Subject: [PATCH 05/23] commit lint 
--- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 7629ebc6718a..3f2611479fa4 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -128,7 +128,7 @@ def _get_devices_by_remote_txn( updates = [duplicate_updates[0]] for i in range(1, len(duplicate_updates)): update = duplicate_updates[i] - prev_update = duplicate_updates[i-1] + prev_update = duplicate_updates[i - 1] if (update[0], update[1]) == (prev_update[0], prev_update[1]): # This is a duplicate, don't copy it over From c988c1e756f9c6268ae15b85c1a36d715d3c52ea Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 10:14:29 -0700 Subject: [PATCH 06/23] WIP --- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/storage/devices.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index be992110032b..c6382bcb0b3c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -364,7 +364,7 @@ def _get_new_device_messages(self): last_device_list = self._last_device_list_stream_id now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list + self._destination, last_device_list, MAX_EDUS_PER_TRANSACTION ) edus.extend( Edu( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 3f2611479fa4..335ea5aadbf7 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -72,7 +72,7 @@ def get_devices_by_user(self, user_id): defer.returnValue({d["device_id"]: d for d in devices}) - def get_devices_by_remote(self, destination, from_stream_id): + def get_devices_by_remote(self, destination, from_stream_id, limit=100): """Get stream of updates to send to remote servers Returns: @@ -92,10 +92,11 @@ def 
get_devices_by_remote(self, destination, from_stream_id): destination, from_stream_id, now_stream_id, + limit, ) def _get_devices_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id + self, txn, destination, from_stream_id, now_stream_id, limit=100 ): # We retrieve n+1 devices from the list of outbound pokes were n is our # maximum. We then check if the very last device has the same stream_id as the @@ -107,13 +108,12 @@ def _get_devices_by_remote_txn( # being that such a large device list update is likely an error. # # Note: The code below assumes this value is at least 1 - maximum_devices = 100 sql = """ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? ORDER BY stream_id LIMIT %d - """ % (maximum_devices + 1) + """ % (limit + 1) txn.execute(sql, (destination, from_stream_id, now_stream_id, False)) duplicate_updates = [r for r in txn] @@ -143,7 +143,7 @@ def _get_devices_by_remote_txn( # Check if the last and second-to-last row's stream_id's are the same offending_stream_id = None if ( - len(updates) > maximum_devices and + len(updates) > limit and updates[-1][2] == updates[-2][2] ): offending_stream_id = updates[-1][2] @@ -159,7 +159,7 @@ def _get_devices_by_remote_txn( # out, then skip this stream_id if len(query_map) == 0: return (now_stream_id + 1, []) - elif len(query_map) >= maximum_devices: + elif len(query_map) >= limit : now_stream_id = max(stream_id for stream_id in itervalues(query_map)) devices = self._get_e2e_device_keys_txn( From 0cb7a60aa5417adee87398fa75e1e14850a94dde Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 12:11:49 -0700 Subject: [PATCH 07/23] split _get_max_stream_id_for_devices_txn into 2 funcs --- synapse/storage/devices.py | 107 ++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 49 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 
f2be77fb724d..34de3b386773 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -78,6 +78,9 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): Returns: (int, list[dict]): current stream id and list of updates """ + if limit < 1: + raise StoreError("Device limit must be at least 1") + now_stream_id = self._device_list_id_gen.get_current_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( @@ -86,7 +89,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): if not has_changed: return (now_stream_id, []) - return self.runInteraction( + updates = self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, destination, @@ -95,51 +98,10 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): limit, ) - def _get_devices_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit - ): - # We retrieve n+1 devices from the list of outbound pokes were n is our - # maximum. We then check if the very last device has the same stream_id as the - # second-to-last device. If so, then we ignore all devices with that stream_id - # and only send the devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we consider the - # device update to be too large, and simply skip the stream_id - the rationale - # being that such a large device list update is likely an error. - # - # Note: The code below assumes this value is at least 1 - sql = """ - SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? 
- ORDER BY stream_id - LIMIT %d - """ % (limit + 1) - txn.execute(sql, (destination, from_stream_id, now_stream_id, False)) - - duplicate_updates = list(txn) - - # Return if there are no updates to send out - if len(duplicate_updates) == 0: + # Return if there are no updates + if len(updates) == 0: return (now_stream_id, []) - # Perform the equivalent of a GROUP BY - # Iterate through the updates list and copy any non-duplicate - # (user_id, device_id) entries - updates = [duplicate_updates[0]] - for i in range(1, len(duplicate_updates)): - update = duplicate_updates[i] - prev_update = duplicate_updates[i - 1] - - if (update[0], update[1]) == (prev_update[0], prev_update[1]): - # This is a duplicate, don't copy it over - # However if its stream_id is higher, copy that to the new list - if update[3] > prev_update[3]: - updates[-1][3] = update[3] - continue - - # Not a duplicate, copy over - updates.append(update) - # Check if the last and second-to-last row's stream_id's are the same offending_stream_id = None if ( @@ -148,20 +110,67 @@ def _get_devices_by_remote_txn( ): offending_stream_id = updates[-1][2] + # Perform the equivalent of a GROUP BY + # Iterate through the updates list and copy non-duplicate + # (user_id, device_id) entries into a map, with the value being + # the max stream_id across each set of duplicate entries + # # maps (user_id, device_id) -> stream_id # as long as their stream_id does not match that of the last row - query_map = { - (r[0], r[1]): r[2] for r in updates - if r[2] is not offending_stream_id - } + query_map = {} + for update in updates: + if update[2] == offending_stream_id: + continue + + key = (update[0], update[1]) + if key in query_map and query_map[key] >= update[2]: + # Preserve larger stream_id + continue + + query_map[key] = update[2] # If we ended up not being left over with any device updates to send # out, then skip this stream_id if len(query_map) == 0: return (now_stream_id + 1, []) - elif len(query_map) >= limit : + 
elif len(query_map) >= limit: now_stream_id = max(stream_id for stream_id in itervalues(query_map)) + return self.runInteraction( + "_get_max_stream_id_for_devices_txn", + self._get_max_stream_id_for_devices_txn, + from_stream_id, + now_stream_id, + query_map, + limit, + ) + + def _get_devices_by_remote_txn( + self, txn, destination, from_stream_id, now_stream_id, limit + ): + # We retrieve n+1 devices from the list of outbound pokes were n is our + # maximum. In our parent function, we then check if the very last + # device has the same stream_id as the second-to-last device. If so, + # then we ignore all devices with that stream_id and only send the + # devices with a lower stream_id. + # + # If when culling the list we end up with no devices afterwards, we + # consider the device update to be too large, and simply skip the + # stream_id - the rationale being that such a large device list update + # is likely an error. + sql = """ + SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + ORDER BY stream_id + LIMIT ? 
+ """ + txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit + 1)) + + return list(txn) + + def _get_max_stream_id_for_devices_txn( + self, txn, destination, from_stream_id, now_stream_id, query_map, limit + ): devices = self._get_e2e_device_keys_txn( txn, query_map.keys(), From 2e5e32ee00797e9cc081d531ac24751e31153c0d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 19:05:10 -0700 Subject: [PATCH 08/23] yield deferreds --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 34de3b386773..a42d8fbd7c7f 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -89,7 +89,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): if not has_changed: return (now_stream_id, []) - updates = self.runInteraction( + updates = yield self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, destination, From 768425901cb2123cef44f98bef5d31a2cd8e1db8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 19:24:41 -0700 Subject: [PATCH 09/23] deferred and missing argument --- synapse/storage/devices.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index a42d8fbd7c7f..f08faf25aa27 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -72,6 +72,7 @@ def get_devices_by_user(self, user_id): defer.returnValue({d["device_id"]: d for d in devices}) + @defer.inlineCallbacks def get_devices_by_remote(self, destination, from_stream_id, limit=100): """Get stream of updates to send to remote servers @@ -100,7 +101,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): # Return if there are no updates if len(updates) == 0: - return (now_stream_id, []) + defer.returnValue((now_stream_id, [])) # Check if the last and second-to-last row's stream_id's are the 
same offending_stream_id = None @@ -136,15 +137,18 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): elif len(query_map) >= limit: now_stream_id = max(stream_id for stream_id in itervalues(query_map)) - return self.runInteraction( + max_stream_id_and_results = yield self.runInteraction( "_get_max_stream_id_for_devices_txn", self._get_max_stream_id_for_devices_txn, + destination, from_stream_id, now_stream_id, query_map, limit, ) + defer.returnValue(max_stream_id_and_results) + def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): From d9078b60e1581eee9c6bf2e0a3ba8295e07a51a0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 19:30:21 -0700 Subject: [PATCH 10/23] missed one --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index f08faf25aa27..86b29541b0aa 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -88,7 +88,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): destination, int(from_stream_id) ) if not has_changed: - return (now_stream_id, []) + defer.returnValue((now_stream_id, [])) updates = yield self.runInteraction( "get_devices_by_remote", From a674d8c728d131cfc10dbc69482800f15e81fcd1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 10 May 2019 19:32:37 -0700 Subject: [PATCH 11/23] and another --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 86b29541b0aa..c11ff589d9e1 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -133,7 +133,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): # If we ended up not being left over with any device updates to send # out, then skip this stream_id if len(query_map) == 0: - return (now_stream_id + 1, []) + 
defer.returnValue((now_stream_id + 1, [])) elif len(query_map) >= limit: now_stream_id = max(stream_id for stream_id in itervalues(query_map)) From 84db73daf3de6e6c1b634c6d12810b0eec7587a9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 21 May 2019 11:49:11 +0100 Subject: [PATCH 12/23] Address some comments and clean things up --- .../sender/per_destination_queue.py | 7 +- synapse/storage/devices.py | 67 ++++++++++--------- tests/storage/test_devices.py | 4 ++ 3 files changed, 42 insertions(+), 36 deletions(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8fdc1d2a5edc..f4d39411d30a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -25,12 +25,10 @@ HttpResponseException, RequestSendFailed, ) -from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. 
@@ -349,9 +347,10 @@ def _pop_pending_edus(self, limit): @defer.inlineCallbacks def _get_new_device_messages(self, limit): last_device_list = self._last_device_list_stream_id - # Will return at most 20 entries + + # Retrieve list of new device updates to send to the destination now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list, limit=limit - 1, + self._destination, last_device_list, limit=limit, ) edus = [ Edu( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c11ff589d9e1..82a50c019d7d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -80,7 +80,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): (int, list[dict]): current stream id and list of updates """ if limit < 1: - raise StoreError("Device limit must be at least 1") + raise RuntimeError("Device limit must be at least 1") now_stream_id = self._device_list_id_gen.get_current_token() @@ -90,28 +90,38 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): if not has_changed: defer.returnValue((now_stream_id, [])) + # We retrieve n+1 devices from the list of outbound pokes where n is + # our outbound device update limit. We then check if the very last + # device has the same stream_id as the second-to-last device. If so, + # then we ignore all devices with that stream_id and only send the + # devices with a lower stream_id. + # + # If when culling the list we end up with no devices afterwards, we + # consider the device update to be too large, and simply skip the + # stream_id; the rationale being that such a large device list update + # is likely an error. 
updates = yield self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, destination, from_stream_id, now_stream_id, - limit, + limit + 1, ) - # Return if there are no updates - if len(updates) == 0: + # Return an empty list if there are no updates + if not updates: defer.returnValue((now_stream_id, [])) # Check if the last and second-to-last row's stream_id's are the same - offending_stream_id = None if ( len(updates) > limit and updates[-1][2] == updates[-2][2] ): - offending_stream_id = updates[-1][2] + now_stream_id = updates[-1][2] # Perform the equivalent of a GROUP BY + # # Iterate through the updates list and copy non-duplicate # (user_id, device_id) entries into a map, with the value being # the max stream_id across each set of duplicate entries @@ -120,26 +130,27 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): # as long as their stream_id does not match that of the last row query_map = {} for update in updates: - if update[2] == offending_stream_id: - continue + if update[2] == now_stream_id: + # Stop processing updates + break key = (update[0], update[1]) - if key in query_map and query_map[key] >= update[2]: - # Preserve larger stream_id - continue - - query_map[key] = update[2] + query_map[key] = max(query_map.get(key, 0), update[2]) # If we ended up not being left over with any device updates to send - # out, then skip this stream_id - if len(query_map) == 0: + # out, then skip this stream_id. + # + # The list of updates associated with this stream_id is too large and + # thus we're just going to assume it was a client-side error and not + # send them. We return an empty list of updates instead. 
+ if not query_map: defer.returnValue((now_stream_id + 1, [])) elif len(query_map) >= limit: now_stream_id = max(stream_id for stream_id in itervalues(query_map)) - max_stream_id_and_results = yield self.runInteraction( - "_get_max_stream_id_for_devices_txn", - self._get_max_stream_id_for_devices_txn, + results = yield self.runInteraction( + "_get_devices_txn", + self._get_devices_txn, destination, from_stream_id, now_stream_id, @@ -147,34 +158,26 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): limit, ) - defer.returnValue(max_stream_id_and_results) + defer.returnValue((now_stream_id, results)) def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): - # We retrieve n+1 devices from the list of outbound pokes were n is our - # maximum. In our parent function, we then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id - the rationale being that such a large device list update - # is likely an error. + """Return device update information for a given remote destination""" sql = """ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? ORDER BY stream_id LIMIT ? 
""" - txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit + 1)) + txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit)) return list(txn) - def _get_max_stream_id_for_devices_txn( + def _get_device_update_edus_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, query_map, limit ): + """Returns a list of device update EDUs as well as E2EE keys""" devices = self._get_e2e_device_keys_txn( txn, query_map.keys(), @@ -218,7 +221,7 @@ def _get_max_stream_id_for_devices_txn( results.append(result) - return (now_stream_id, results) + return results def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index aef4dfaf57a0..d2e17403a7b3 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -71,6 +71,10 @@ def test_get_devices_by_user(self): res["device2"], ) + @defer.inlineCallbacks + def test_get_devices_by_remote(self): + self.store.store_device() + @defer.inlineCallbacks def test_update_device(self): yield self.store.store_device("user_id", "device_id", "display_name 1") From cf7734364f7a8e042a294196634f2c140daaa0c8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 21 May 2019 13:58:04 +0100 Subject: [PATCH 13/23] lint --- synapse/federation/sender/per_destination_queue.py | 2 ++ synapse/server.pyi | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index f4d39411d30a..564c57203d33 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -25,10 +25,12 @@ HttpResponseException, RequestSendFailed, ) +from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from 
synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. diff --git a/synapse/server.pyi b/synapse/server.pyi index 3ba3a967c2a9..9583e82d5213 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -18,7 +18,6 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage - class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: From fcda6071d0726add2706339a3c27099828e9c292 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 21 May 2019 15:49:13 +0100 Subject: [PATCH 14/23] test progress --- synapse/storage/devices.py | 17 ++++++++--------- tests/storage/test_devices.py | 28 +++++++++++++++++++++++++++- tests/utils.py | 2 +- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 82a50c019d7d..616e7e92cb91 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from six import iteritems, itervalues +from six import iteritems from canonicaljson import json @@ -79,11 +79,9 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): Returns: (int, list[dict]): current stream id and list of updates """ - if limit < 1: - raise RuntimeError("Device limit must be at least 1") - now_stream_id = self._device_list_id_gen.get_current_token() + # Why is this False in the test? 
has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) ) @@ -115,9 +113,11 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): # Check if the last and second-to-last row's stream_id's are the same if ( + len(updates) > 1 and len(updates) > limit and updates[-1][2] == updates[-2][2] ): + # If so, cap our maximum stream_id at that final stream_id now_stream_id = updates[-1][2] # Perform the equivalent of a GROUP BY @@ -138,19 +138,18 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): query_map[key] = max(query_map.get(key, 0), update[2]) # If we ended up not being left over with any device updates to send - # out, then skip this stream_id. + # out (because there was more device updates with the same stream_id + # that our defined limit allows), then just skip this stream_id. # # The list of updates associated with this stream_id is too large and # thus we're just going to assume it was a client-side error and not # send them. We return an empty list of updates instead. if not query_map: defer.returnValue((now_stream_id + 1, [])) - elif len(query_map) >= limit: - now_stream_id = max(stream_id for stream_id in itervalues(query_map)) results = yield self.runInteraction( - "_get_devices_txn", - self._get_devices_txn, + "_get_device_update_edus_by_remote_txn", + self._get_device_update_edus_by_remote_txn, destination, from_stream_id, now_stream_id, diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index d2e17403a7b3..061b15e34787 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + from twisted.internet import defer import synapse.api.errors @@ -20,6 +22,8 @@ import tests.unittest import tests.utils +logger = logging.getLogger(__name__) + class DeviceStoreTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): @@ -73,7 +77,29 @@ def test_get_devices_by_user(self): @defer.inlineCallbacks def test_get_devices_by_remote(self): - self.store.store_device() + device_ids = ["device_id1", "device_id2"] + + # Add a device update to the stream + stream_id = yield self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], + ) + + res = yield self.store.get_devices_by_remote("somehost", 0, limit=100) + + logger.info("Res: %s", res) + self.assertEqual(1, 2) + + device_updates = res[1] + + for update in device_updates: + d_id = update["device_id"] + if d_id in device_ids: + del device_ids[d_id] + + logger.info("stream_id: %s, updates: %s", stream_id, res) + + # All device_ids should've been accounted for + self.assertEqual(len(device_ids), 0) @defer.inlineCallbacks def test_update_device(self): diff --git a/tests/utils.py b/tests/utils.py index c2ef4b0bb580..56bd80d6136a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -247,7 +247,7 @@ def setup_test_homeserver( else: config.database_config = { "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + "args": {"database": "test.db", "cp_min": 1, "cp_max": 1}, } db_engine = create_engine(config.database_config) From 69c0c1b591944bac81547886161d77b194c3b6f6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 22 May 2019 18:01:41 +0100 Subject: [PATCH 15/23] fix test --- synapse/storage/devices.py | 7 ++++--- tests/storage/test_devices.py | 9 ++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 616e7e92cb91..fffa8f17e05b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -81,7 +81,6 @@ def get_devices_by_remote(self, 
destination, from_stream_id, limit=100): """ now_stream_id = self._device_list_id_gen.get_current_token() - # Why is this False in the test? has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) ) @@ -111,6 +110,8 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): if not updates: defer.returnValue((now_stream_id, [])) + stream_id_cutoff = now_stream_id + 1 + # Check if the last and second-to-last row's stream_id's are the same if ( len(updates) > 1 and @@ -118,7 +119,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): updates[-1][2] == updates[-2][2] ): # If so, cap our maximum stream_id at that final stream_id - now_stream_id = updates[-1][2] + stream_id_cutoff = updates[-1][2] # Perform the equivalent of a GROUP BY # @@ -130,7 +131,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): # as long as their stream_id does not match that of the last row query_map = {} for update in updates: - if update[2] == now_stream_id: + if update[2] >= stream_id_cutoff: # Stop processing updates break diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 061b15e34787..230f3f473328 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -84,19 +84,14 @@ def test_get_devices_by_remote(self): "user_id", device_ids, ["somehost"], ) - res = yield self.store.get_devices_by_remote("somehost", 0, limit=100) - - logger.info("Res: %s", res) - self.assertEqual(1, 2) + res = yield self.store.get_devices_by_remote("somehost", -1, limit=100) device_updates = res[1] for update in device_updates: d_id = update["device_id"] if d_id in device_ids: - del device_ids[d_id] - - logger.info("stream_id: %s, updates: %s", stream_id, res) + device_ids.remove(d_id) # All device_ids should've been accounted for self.assertEqual(len(device_ids), 0) From 06fa759a92bb3ef101804b270735a291d97b05d7 Mon Sep 17 00:00:00 2001 From: 
Andrew Morgan Date: Thu, 23 May 2019 10:02:36 +0100 Subject: [PATCH 16/23] lint --- tests/storage/test_devices.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 230f3f473328..3ff51e5f026c 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -80,10 +80,11 @@ def test_get_devices_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add a device update to the stream - stream_id = yield self.store.add_device_change_to_streams( + yield self.store.add_device_change_to_streams( "user_id", device_ids, ["somehost"], ) + # Get all device updates ever meant for this remote res = yield self.store.get_devices_by_remote("somehost", -1, limit=100) device_updates = res[1] From 5c7bb2cfa1f35be453600d1e032b4f78fc4cd8b3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 23 May 2019 11:16:28 +0100 Subject: [PATCH 17/23] Don't break buildkite --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 56bd80d6136a..c2ef4b0bb580 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -247,7 +247,7 @@ def setup_test_homeserver( else: config.database_config = { "name": "sqlite3", - "args": {"database": "test.db", "cp_min": 1, "cp_max": 1}, + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } db_engine = create_engine(config.database_config) From c674c952629a88cdca021abc1879a8f788d24b81 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 23 May 2019 11:19:19 +0100 Subject: [PATCH 18/23] logging not needed --- tests/storage/test_devices.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 3ff51e5f026c..6a4a37ce8b00 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import logging - from twisted.internet import defer import synapse.api.errors @@ -22,8 +20,6 @@ import tests.unittest import tests.utils -logger = logging.getLogger(__name__) - class DeviceStoreTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): From 3dbb5f013930a7bff87a61fd3a3aafd7c3c5f32a Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 23 May 2019 11:19:59 +0100 Subject: [PATCH 19/23] unnecessary line removal --- synapse/server.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9583e82d5213..3ba3a967c2a9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -18,6 +18,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage + class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: From 322e1a39a76075f2d4b2f47b246c1298a656db72 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 23 May 2019 11:53:38 +0100 Subject: [PATCH 20/23] ok isort --- synapse/server.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3ba3a967c2a9..9583e82d5213 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -18,7 +18,6 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage - class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: From da6a2ad6d4bad6189ff345b47d63834fbf17323e Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 3 Jun 2019 15:33:53 +0100 Subject: [PATCH 21/23] address review comments. 
add more tests --- synapse/storage/devices.py | 105 ++++++++++++++++++++++------------ tests/storage/test_devices.py | 59 ++++++++++++++++++- 2 files changed, 123 insertions(+), 41 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index fffa8f17e05b..edc3f1701005 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -73,7 +73,7 @@ def get_devices_by_user(self, user_id): defer.returnValue({d["device_id"]: d for d in devices}) @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit=100): + def get_devices_by_remote(self, destination, from_stream_id, limit): """Get stream of updates to send to remote servers Returns: @@ -87,6 +87,8 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): if not has_changed: defer.returnValue((now_stream_id, [])) + logger.debug("Getting from %d to %d", from_stream_id, now_stream_id) + # We retrieve n+1 devices from the list of outbound pokes where n is # our outbound device update limit. We then check if the very last # device has the same stream_id as the second-to-last device. 
If so, @@ -112,12 +114,8 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): stream_id_cutoff = now_stream_id + 1 - # Check if the last and second-to-last row's stream_id's are the same - if ( - len(updates) > 1 and - len(updates) > limit and - updates[-1][2] == updates[-2][2] - ): + # Check if the last and second-to-last rows' stream_id's are the same + if len(updates) > limit: # If so, cap our maximum stream_id at that final stream_id stream_id_cutoff = updates[-1][2] @@ -138,24 +136,21 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): key = (update[0], update[1]) query_map[key] = max(query_map.get(key, 0), update[2]) - # If we ended up not being left over with any device updates to send - # out (because there was more device updates with the same stream_id - # that our defined limit allows), then just skip this stream_id. - # - # The list of updates associated with this stream_id is too large and - # thus we're just going to assume it was a client-side error and not - # send them. We return an empty list of updates instead. + # If we didn't find any updates with a stream_id lower than the cutoff, it + # means that there are more than limit updates all of which have the same + # stream_id. + + # That should only happen if a client is spamming the server with new + # devices, in which case E2E isn't going to work well anyway. We'll just + # skip that stream_id and return an empty list, and continue with the next + # stream_id next time.
if not query_map: - defer.returnValue((now_stream_id + 1, [])) + defer.returnValue((stream_id_cutoff, [])) - results = yield self.runInteraction( - "_get_device_update_edus_by_remote_txn", - self._get_device_update_edus_by_remote_txn, + results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, - now_stream_id, query_map, - limit, ) defer.returnValue((now_stream_id, results)) @@ -163,7 +158,18 @@ def get_devices_by_remote(self, destination, from_stream_id, limit=100): def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): - """Return device update information for a given remote destination""" + """Return device update information for a given remote destination + + Args: + txn (LoggingTransaction): The transaction to execute + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + now_stream_id (int): The maximum stream_id to filter updates by, inclusive + limit (int): Maximum number of device updates to return + + Returns: + List: List of device updates + """ sql = """ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? 
@@ -174,30 +180,42 @@ def _get_devices_by_remote_txn( return list(txn) - def _get_device_update_edus_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, query_map, limit + @defer.inlineCallbacks + def _get_device_update_edus_by_remote( + self, destination, from_stream_id, query_map, ): - """Returns a list of device update EDUs as well as E2EE keys""" - devices = self._get_e2e_device_keys_txn( - txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, - ) + """Returns a list of device update EDUs as well as E2EE keys + + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + query_map (Dict[(str, str): int]): Dictionary mapping + user_id/device_id to update stream_id + + Returns: + List[Dict]: List of objects representing a device update EDU - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ?
""" + devices = yield self.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + True, + True, + ) results = [] for user_id, user_devices in iteritems(devices): # The prev_id for the first row is always the last row before # `from_stream_id` - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - rows = txn.fetchall() - prev_id = rows[0][0] + update_edus = yield self.runInteraction( + "_get_device_update_edus_by_remote_txn", + self._get_device_update_edus_by_remote_txn, + destination, + user_id, + from_stream_id, + ) + prev_id = update_edus[0][0] for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] result = { @@ -221,7 +239,18 @@ def _get_device_update_edus_by_remote_txn( results.append(result) - return results + defer.returnValue(results) + + def _get_device_update_edus_by_remote_txn( + self, txn, destination, user_id, from_stream_id, + ): + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + return txn.fetchall() def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. 
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6a4a37ce8b00..66de97f2c664 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -75,15 +75,68 @@ def test_get_devices_by_user(self): def test_get_devices_by_remote(self): device_ids = ["device_id1", "device_id2"] - # Add a device update to the stream + # Add two device updates with a single stream_id yield self.store.add_device_change_to_streams( "user_id", device_ids, ["somehost"], ) # Get all device updates ever meant for this remote - res = yield self.store.get_devices_by_remote("somehost", -1, limit=100) + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "somehost", -1, limit=100, + ) + + # Check original device_ids are contained within these updates + self._check_devices_in_updates(device_ids, device_updates) + + # Test breaking the update limit in 1, 101, and 1 device_id segments + # First test adding an update with 1 device + device_ids = ["device_id0"] + yield self.store.add_device_change_to_streams( + "user_id", device_ids, ["someotherhost"], + ) + + # Get all device updates ever meant for this remote + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", now_stream_id, limit=100, + ) + + # Check we got a single device update + self._check_devices_in_updates(device_ids, device_updates) + + # Try adding 101 updates (we expect to get an empty list back as it + # broke the limit) + device_ids = ["device_id" + str(i + 1) for i in range(101)] + + yield self.store.add_device_change_to_streams( + "user_id", device_ids, ["someotherhost"], + ) + + # Get all device updates meant for this remote. + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", now_stream_id, limit=100, + ) + + # We should get an empty list back as this broke the limit + self.assertEqual(len(device_updates), 0) + + # Try to insert one more device update. 
The 101 devices should've been cleared, + # so we should now just get one device update: this new one + device_ids = ["newdevice"] + yield self.store.add_device_change_to_streams( + "user_id", device_ids, ["someotherhost"], + ) + + # Get all device updates meant for this remote. + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", now_stream_id, limit=100, + ) + + # We should just get our one device update + self._check_devices_in_updates(device_ids, device_updates) - device_updates = res[1] + def _check_devices_in_updates(self, device_ids, device_updates): + """Check that an specific device ids exist in a list of device update EDUs""" + self.assertEqual(len(device_updates), len(device_ids)) for update in device_updates: d_id = update["device_id"] From 2231131563fdfe416cce883c18e2a753aef9cc9d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 3 Jun 2019 16:29:43 +0100 Subject: [PATCH 22/23] Remove debug logging --- synapse/storage/devices.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index edc3f1701005..295e7acdca3d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -87,8 +87,6 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): if not has_changed: defer.returnValue((now_stream_id, [])) - logger.debug("Getting from %d to %d", from_stream_id, now_stream_id) - # We retrieve n+1 devices from the list of outbound pokes where n is # our outbound device update limit. We then check if the very last # device has the same stream_id as the second-to-last device. 
If so, From 0de7b1720153cc2a739f7bfda91e5068cac6b219 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 6 Jun 2019 23:24:59 +0100 Subject: [PATCH 23/23] minor tweaks --- synapse/storage/devices.py | 50 +++++++++++++------------- tests/storage/test_devices.py | 66 ++++++++++++++++------------------- 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 295e7acdca3d..d102e07372cc 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -77,7 +77,9 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): """Get stream of updates to send to remote servers Returns: - (int, list[dict]): current stream id and list of updates + Deferred[tuple[int, list[dict]]]: + current stream id (ie, the stream id of the last update included in the + response), and the list of updates """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -110,12 +112,13 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): if not updates: defer.returnValue((now_stream_id, [])) - stream_id_cutoff = now_stream_id + 1 - - # Check if the last and second-to-last rows' stream_id's are the same + # if we have exceeded the limit, we need to exclude any results with the + # same stream_id as the last row. 
if len(updates) > limit: - # If so, cap our maximum stream_id at that final stream_id stream_id_cutoff = updates[-1][2] + now_stream_id = stream_id_cutoff - 1 + else: + stream_id_cutoff = None # Perform the equivalent of a GROUP BY # @@ -127,7 +130,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): # as long as their stream_id does not match that of the last row query_map = {} for update in updates: - if update[2] >= stream_id_cutoff: + if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: # Stop processing updates break @@ -198,22 +201,17 @@ def _get_device_update_edus_by_remote( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), - True, - True, + include_all_devices=True, + include_deleted_devices=True, ) results = [] for user_id, user_devices in iteritems(devices): # The prev_id for the first row is always the last row before # `from_stream_id` - update_edus = yield self.runInteraction( - "_get_device_update_edus_by_remote_txn", - self._get_device_update_edus_by_remote_txn, - destination, - user_id, - from_stream_id, + prev_id = yield self._get_last_device_update_for_remote_user( + destination, user_id, from_stream_id, ) - prev_id = update_edus[0][0] for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] result = { @@ -239,16 +237,20 @@ def _get_device_update_edus_by_remote( defer.returnValue(results) - def _get_device_update_edus_by_remote_txn( - self, txn, destination, user_id, from_stream_id, + def _get_last_device_update_for_remote_user( + self, destination, user_id, from_stream_id, ): - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ? 
- """ - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - return txn.fetchall() + def f(txn): + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + return rows[0][0] + + return self.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 66de97f2c664..6396ccddb52b 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -88,63 +88,57 @@ def test_get_devices_by_remote(self): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) + @defer.inlineCallbacks + def test_get_devices_by_remote_limited(self): # Test breaking the update limit in 1, 101, and 1 device_id segments - # First test adding an update with 1 device - device_ids = ["device_id0"] + + # first add one device + device_ids1 = ["device_id0"] yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["someotherhost"], + "user_id", device_ids1, ["someotherhost"], ) - # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + # then add 101 + device_ids2 = ["device_id" + str(i + 1) for i in range(101)] + yield self.store.add_device_change_to_streams( + "user_id", device_ids2, ["someotherhost"], ) - # Check we got a single device update - self._check_devices_in_updates(device_ids, device_updates) + # then one more + device_ids3 = ["newdevice"] + yield self.store.add_device_change_to_streams( + "user_id", device_ids3, ["someotherhost"], + 
) - # Try adding 101 updates (we expect to get an empty list back as it - # broke the limit) - device_ids = ["device_id" + str(i + 1) for i in range(101)] + # + # now read them back. + # - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["someotherhost"], + # first we should get a single update + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", -1, limit=100, ) + self._check_devices_in_updates(device_ids1, device_updates) - # Get all device updates meant for this remote. + # Then we should get an empty list back as the 101 devices broke the limit now_stream_id, device_updates = yield self.store.get_devices_by_remote( "someotherhost", now_stream_id, limit=100, ) - - # We should get an empty list back as this broke the limit self.assertEqual(len(device_updates), 0) - # Try to insert one more device update. The 101 devices should've been cleared, - # so we should now just get one device update: this new one - device_ids = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["someotherhost"], - ) - - # Get all device updates meant for this remote. 
+ # The 101 devices should've been cleared, so we should now just get one device + # update now_stream_id, device_updates = yield self.store.get_devices_by_remote( "someotherhost", now_stream_id, limit=100, ) + self._check_devices_in_updates(device_ids3, device_updates) - # We should just get our one device update - self._check_devices_in_updates(device_ids, device_updates) - - def _check_devices_in_updates(self, device_ids, device_updates): + def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" - self.assertEqual(len(device_updates), len(device_ids)) - - for update in device_updates: - d_id = update["device_id"] - if d_id in device_ids: - device_ids.remove(d_id) + self.assertEqual(len(device_updates), len(expected_device_ids)) - # All device_ids should've been accounted for - self.assertEqual(len(device_ids), 0) + received_device_ids = {update["device_id"] for update in device_updates} + self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks def test_update_device(self):