Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Prevent multiple device list updates from breaking a batch send #5156

Merged
merged 25 commits into from
Jun 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5156.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Prevent federation device list updates breaking when processing multiple updates at once.
5 changes: 3 additions & 2 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,10 @@ def _pop_pending_edus(self, limit):
@defer.inlineCallbacks
def _get_new_device_messages(self, limit):
last_device_list = self._last_device_list_stream_id
# Will return at most 20 entries

# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
self._destination, last_device_list
self._destination, last_device_list, limit=limit,
)
edus = [
Edu(
Expand Down
152 changes: 123 additions & 29 deletions synapse/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging

from six import iteritems, itervalues
from six import iteritems

from canonicaljson import json

Expand Down Expand Up @@ -72,67 +72,146 @@ def get_devices_by_user(self, user_id):

defer.returnValue({d["device_id"]: d for d in devices})

def get_devices_by_remote(self, destination, from_stream_id):
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers

Returns:
(int, list[dict]): current stream id and list of updates
Deferred[tuple[int, list[dict]]]:
current stream id (ie, the stream id of the last update included in the
response), and the list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])

return self.runInteraction(
defer.returnValue((now_stream_id, []))

# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. If more than n rows come back, we
# cannot know whether every update sharing the last row's stream_id was
# fetched, so we ignore all devices with that stream_id and only send
# the devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
"get_devices_by_remote",
self._get_devices_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
limit + 1,
)

# Return an empty list if there are no updates
if not updates:
defer.returnValue((now_stream_id, []))

# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None

# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> stream_id
# as long as their stream_id does not match that of the last row
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
# Stop processing updates
break

key = (update[0], update[1])
query_map[key] = max(query_map.get(key, 0), update[2])

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# stream_id.

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
defer.returnValue((stream_id_cutoff, []))

results = yield self._get_device_update_edus_by_remote(
destination,
from_stream_id,
query_map,
)

defer.returnValue((now_stream_id, results))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here.

maybe we can get rid of stream_id_cutoff altogether, and just use now_stream_id ?


def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination

Args:
txn (LoggingTransaction): The transaction to execute
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
limit (int): Maximum number of device updates to return

Returns:
List: List of device updates
"""
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
ORDER BY stream_id
LIMIT ?
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))

# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
return list(txn)

if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
@defer.inlineCallbacks
def _get_device_update_edus_by_remote(
self, destination, from_stream_id, query_map,
):
"""Returns a list of device update EDUs as well as E2EE keys

devices = self._get_e2e_device_keys_txn(
txn,
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): int]): Dictionary mapping
user_id/device_id to update stream_id

Returns:
List[Dict]: List of objects representing a device update EDU

"""
devices = yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)

prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""

results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
prev_id = yield self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id,
)
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
Expand All @@ -156,7 +235,22 @@ def _get_devices_by_remote_txn(

results.append(result)

return (now_stream_id, results)
defer.returnValue(results)

def _get_last_device_update_for_remote_user(
    self, destination, user_id, from_stream_id,
):
    """Look up the newest recorded stream_id for a user on a destination

    Reads device_lists_outbound_last_success and picks the largest
    stream_id that is no greater than `from_stream_id` for the given
    destination and user.

    Args:
        destination (str): The host the device updates are intended for
        user_id (str): The user whose rows are examined
        from_stream_id (int): The maximum stream_id to consider, inclusive

    Returns:
        Whatever self.runInteraction returns for this lookup (callers yield
        it, so presumably a Deferred). The interaction itself produces the
        largest matching stream_id, or 0 when no row matches.
    """
    # Spelled out line by line so the statement text does not depend on how
    # deeply this method is indented.
    sql = (
        "\n"
        "SELECT coalesce(max(stream_id), 0) as stream_id\n"
        "FROM device_lists_outbound_last_success\n"
        "WHERE destination = ? AND user_id = ? AND stream_id <= ?\n"
    )

    def _select_prev_id_txn(txn):
        txn.execute(sql, (destination, user_id, from_stream_id))
        # An aggregate without GROUP BY yields a single row, and coalesce
        # turns the "no matching rows" NULL into 0.
        return txn.fetchall()[0][0]

    return self.runInteraction(
        "get_last_device_update_for_remote_user", _select_prev_id_txn,
    )

def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
Expand Down
69 changes: 69 additions & 0 deletions tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,75 @@ def test_get_devices_by_user(self):
res["device2"],
)

@defer.inlineCallbacks
def test_get_devices_by_remote(self):
    """Updates queued for a remote host are all returned by a single read"""
    device_ids = ["device_id1", "device_id2"]

    # Add two device updates with a single stream_id
    yield self.store.add_device_change_to_streams(
        "user_id", device_ids, ["somehost"],
    )

    # Get all device updates ever meant for this remote. The stream id that
    # comes back alongside them is not needed by this test.
    _, device_updates = yield self.store.get_devices_by_remote(
        "somehost", -1, limit=100,
    )

    # Check original device_ids are contained within these updates
    self._check_devices_in_updates(device_ids, device_updates)

@defer.inlineCallbacks
def test_get_devices_by_remote_limited(self):
    """Read back changes queued as 1, 101 and 1 devices using limit=100"""
    user = "user_id"
    remote = "someotherhost"

    # Queue three separate changes for the remote: a single device, then
    # 101 devices at once (one more than the limit used below), then a
    # single device again.
    single_before = ["device_id0"]
    yield self.store.add_device_change_to_streams(
        user, single_before, [remote],
    )

    oversized = ["device_id" + str(n) for n in range(1, 102)]
    yield self.store.add_device_change_to_streams(
        user, oversized, [remote],
    )

    single_after = ["newdevice"]
    yield self.store.add_device_change_to_streams(
        user, single_after, [remote],
    )

    # Now read them back, feeding each returned stream position into the
    # next call.

    # First read: only the initial single-device change
    stream_pos, edus = yield self.store.get_devices_by_remote(
        remote, -1, limit=100,
    )
    self._check_devices_in_updates(single_before, edus)

    # Second read: an empty list, as the 101 devices broke the limit
    stream_pos, edus = yield self.store.get_devices_by_remote(
        remote, stream_pos, limit=100,
    )
    self.assertEqual(len(edus), 0)

    # Third read: the 101 devices should have been skipped, leaving just
    # the final single-device change
    stream_pos, edus = yield self.store.get_devices_by_remote(
        remote, stream_pos, limit=100,
    )
    self._check_devices_in_updates(single_after, edus)

def _check_devices_in_updates(self, expected_device_ids, device_updates):
    """Check that specific device ids exist in a list of device update EDUs"""
    # There must be exactly as many updates as expected ids...
    self.assertEqual(len(device_updates), len(expected_device_ids))

    # ...and the ids carried by those updates must be the expected ones.
    seen_ids = set()
    for edu in device_updates:
        seen_ids.add(edu["device_id"])
    self.assertEqual(seen_ids, set(expected_device_ids))

@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device("user_id", "device_id", "display_name 1")
Expand Down