Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Periodically send pings to detect dead Redis connections #9218

Merged
merged 7 commits into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9218.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.
12 changes: 8 additions & 4 deletions stubs/txredisapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ from typing import List, Optional, Type, Union

class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...

class SubscriberProtocol:
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
Expand All @@ -40,14 +41,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...

class SubscriberFactory:
def buildProtocol(self, addr): ...

class ConnectionHandler: ...

class RedisFactory:
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
Expand All @@ -60,3 +60,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
def buildProtocol(self, addr) -> RedisProtocol: ...

class SubscriberFactory(RedisFactory):
def __init__(self): ...
8 changes: 6 additions & 2 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
Expand Down Expand Up @@ -63,6 +64,9 @@
TypingStream,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


Expand All @@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
Expand Down Expand Up @@ -300,7 +304,7 @@ def start_replication(self, hs):

# First create the connection for sending commands.
outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
hs=hs,
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
Expand Down
117 changes: 86 additions & 31 deletions synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Type, cast

import txredisapi

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
Expand Down Expand Up @@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.

Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to and publish from
(not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""

handler = None # type: ReplicationCommandHandler
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
synapse_handler = None # type: ReplicationCommandHandler
synapse_stream_name = None # type: str
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -88,19 +89,19 @@ async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
self.handler.new_connection(self)
self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")

# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# other side won't know we've connected and so won't issue a REPLICATE.
self.handler.send_positions_to_connection(self)
self.synapse_handler.send_positions_to_connection(self)

def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
Expand Down Expand Up @@ -137,7 +138,7 @@ def handle_command(self, cmd: Command) -> None:
cmd: received command
"""

cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
Expand All @@ -155,7 +156,7 @@ def handle_command(self, cmd: Command) -> None:
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
self.synapse_handler.lost_connection(self)

# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
Expand Down Expand Up @@ -183,11 +184,59 @@ async def _async_send_command(self, cmd: Command):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()

await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
)
)


class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
clokep marked this conversation as resolved.
Show resolved Hide resolved
class SynapseRedisFactory(txredisapi.RedisFactory):
"""A subclass of RedisFactory that periodically sends pings to ensure that
we detect dead connections.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
hs: "HomeServer",
uuid: str,
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = txredisapi.ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
):
# We want to ensure that we time out when sending pings on dead
# connections, rather than just hanging.
if replyTimeout is None:
replyTimeout = 30
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

super().__init__(
uuid,
dbid,
poolsize,
isLazy,
handler,
charset,
password,
replyTimeout,
convertNumbers,
)

hs.get_clock().looping_call(self._send_ping, 30 * 1000)

@wrap_as_background_process("redis_ping")
async def _send_ping(self):
    """Ping every pooled redis connection so that dead ones are noticed.

    Each ping is awaited in turn. A failed ping is logged at warning level
    and the loop carries on with the next connection rather than aborting.
    """
    # Iterate over a snapshot of the pool: we yield to the reactor on every
    # ping, and the pool list may be modified while we wait (e.g. when a
    # connection is dropped), which would otherwise make the loop skip
    # entries.
    for connection in list(self.pool):
        try:
            await make_deferred_yieldable(connection.ping())
        except Exception:
            logger.warning("Failed to send ping to a redis connection")


class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.

Expand All @@ -206,34 +255,37 @@ def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):

super().__init__()

# This sets the password on the RedisFactory base class (as
# SubscriberFactory constructor doesn't pass it through).
self.password = hs.config.redis.redis_password
super().__init__(
hs,
"subscriber",
None,
1,
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
replyTimeout=30,
password=hs.config.redis.redis_password,
)

self.handler = hs.get_tcp_replication()
self.stream_name = hs.hostname
self.synapse_handler = hs.get_tcp_replication()
self.synapse_stream_name = hs.hostname

self.outbound_redis_connection = outbound_redis_connection
self.synapse_outbound_redis_connection = outbound_redis_connection

def buildProtocol(self, addr):
p = super().buildProtocol(addr) # type: RedisSubscriber
p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)

# We do this here rather than add to the constructor of `RedisSubscriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
p.password = self.password
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name

return p


def lazyConnection(
reactor,
hs: "HomeServer",
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
Expand All @@ -252,7 +304,8 @@ def lazyConnection(
poolsize = 1

uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
factory = SynapseRedisFactory(
hs,
uuid,
dbid,
poolsize,
Expand All @@ -264,6 +317,8 @@ def lazyConnection(
convertNumbers,
)
factory.continueTrying = reconnect

reactor = hs.get_reactor()
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)

Expand Down