Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Port synapse.replication.tcp to async/await (#6666)
Browse files Browse the repository at this point in the history
* commit '48c3a9688':
  Port synapse.replication.tcp to async/await (#6666)
  • Loading branch information
anoadragon453 committed Mar 23, 2020
2 parents 888e203 + 48c3a96 commit 730dac5
Show file tree
Hide file tree
Showing 15 changed files with 80 additions and 105 deletions.
1 change: 1 addition & 0 deletions changelog.d/6666.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Port `synapse.replication.tcp` to async/await.
3 changes: 1 addition & 2 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def build_tcp_replication(self):


class AdminCmdReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
pass

def get_streams_to_replicate(self):
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ def __init__(self, hs):
super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)

if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ def __init__(self, hs):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self)

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(FederationSenderReplicationHandler, self).on_rdata(
async def on_rdata(self, stream_name, token, rows):
await super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,8 @@ def __init__(self, hs):

self.pusher_pool = hs.get_pusherpool()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows)

@defer.inlineCallbacks
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/synchrotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ def __init__(self, hs):
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows)

def get_streams_to_replicate(self):
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/user_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,8 @@ def __init__(self, hs):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(UserDirectoryReplicationHandler, self).on_rdata(
async def on_rdata(self, stream_name, token, rows):
await super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == EventsStream.NAME:
Expand Down
4 changes: 3 additions & 1 deletion synapse/federation/send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def get_current_token(self):
def federation_ack(self, token):
self._clear_queue_before_pos(token)

def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
"""Get rows to be sent over federation between the two tokens
Args:
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _push_update_local(self, member, typing):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)

def get_all_typing_updates(self, last_id, current_id):
async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id:
return []

Expand Down
11 changes: 4 additions & 7 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def start_replication(self, hs):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)

def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
Expand All @@ -121,20 +121,17 @@ def on_rdata(self, stream_name, token, rows):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
Returns:
Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows)
self.store.process_replication_rows(stream_name, token, rows)

def on_position(self, stream_name, token):
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
return self.store.process_replication_rows(stream_name, token, [])
self.store.process_replication_rows(stream_name, token, [])

def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
Expand Down
72 changes: 32 additions & 40 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string

from .streams import STREAMS_MAP

connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
Expand Down Expand Up @@ -241,19 +240,16 @@ def lineReceived(self, line):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)

def handle_command(self, cmd):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd (synapse.replication.tcp.commands.Command): received command
Returns:
Deferred
cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
return handler(cmd)
await handler(cmd)

def close(self):
logger.warning("[%s] Closing connection", self.id())
Expand Down Expand Up @@ -326,10 +322,10 @@ def _send_pending_commands(self):
for cmd in pending:
self.send_command(cmd)

def on_PING(self, line):
async def on_PING(self, line):
self.received_ping = True

def on_ERROR(self, cmd):
async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)

def pauseProducing(self):
Expand Down Expand Up @@ -429,16 +425,16 @@ def connectionMade(self):
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)

def on_NAME(self, cmd):
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data

def on_USER_SYNC(self, cmd):
return self.streamer.on_user_sync(
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)

def on_REPLICATE(self, cmd):
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token

Expand All @@ -449,23 +445,23 @@ def on_REPLICATE(self, cmd):
for stream in iterkeys(self.streamer.streams_by_name)
]

return make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
return self.subscribe_to_stream(stream_name, token)
await self.subscribe_to_stream(stream_name, token)

def on_FEDERATION_ACK(self, cmd):
return self.streamer.federation_ack(cmd.token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)

def on_REMOVE_PUSHER(self, cmd):
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)

def on_INVALIDATE_CACHE(self, cmd):
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)

def on_USER_IP(self, cmd):
return self.streamer.on_user_ip(
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
Expand All @@ -474,8 +470,7 @@ def on_USER_IP(self, cmd):
cmd.last_seen,
)

@defer.inlineCallbacks
def subscribe_to_stream(self, stream_name, token):
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
Expand All @@ -487,7 +482,7 @@ def subscribe_to_stream(self, stream_name, token):

try:
# Get missing updates
updates, current_token = yield self.streamer.get_stream_updates(
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)

Expand Down Expand Up @@ -572,22 +567,19 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
Returns:
Deferred|None
"""
raise NotImplementedError()

@abc.abstractmethod
def on_position(self, stream_name, token):
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()

Expand Down Expand Up @@ -676,12 +668,12 @@ def connectionMade(self):
if not self.streams_connecting:
self.handler.finished_connecting()

def on_SERVER(self, cmd):
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")

def on_RDATA(self, cmd):
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()

Expand All @@ -701,19 +693,19 @@ def on_RDATA(self, cmd):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
await self.handler.on_rdata(stream_name, cmd.token, rows)

def on_POSITION(self, cmd):
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()

return self.handler.on_position(cmd.stream_name, cmd.token)
await self.handler.on_position(cmd.stream_name, cmd.token)

def on_SYNC(self, cmd):
return self.handler.on_sync(cmd.data)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)

def replicate(self, stream_name, token):
"""Send the subscription request to the server
Expand Down
Loading

0 comments on commit 730dac5

Please sign in to comment.