Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing read race condition during pubsub #1737

Merged
merged 11 commits into from
Dec 23, 2021
74 changes: 68 additions & 6 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,18 +1276,17 @@ def __init__(
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
self.subscribed_event = threading.Event()
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
self.health_check_response = [b"pong", self.health_check_response_b]
self.reset()

def __enter__(self):
Expand All @@ -1312,9 +1311,11 @@ def reset(self):
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
self.health_check_response_counter = 0
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()

def close(self):
self.reset()
Expand All @@ -1340,7 +1341,7 @@ def on_connect(self, connection):
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
return bool(self.channels or self.patterns)
return self.subscribed_event.is_set()

def execute_command(self, *args):
"Execute a publish/subscribe command"
Expand All @@ -1358,8 +1359,28 @@ def execute_command(self, *args):
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)

def clean_health_check_responses(self):
"""
If any health check responses are present, clean them
"""
ttl = 10
conn = self.connection
while self.health_check_response_counter > 0 and ttl > 0:
if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
response = self._execute(conn, conn.read_response)
if self.is_health_check_response(response):
self.health_check_response_counter -= 1
else:
raise PubSubError(
"A non health check response was cleaned by "
"execute_command: {0}".format(response)
)
ttl -= 1

def _disconnect_raise_connect(self, conn, error):
"""
Close the connection and raise an exception
Expand Down Expand Up @@ -1399,11 +1420,23 @@ def parse_response(self, block=True, timeout=0):
return None
response = self._execute(conn, conn.read_response)

if conn.health_check_interval and response == self.health_check_response:
if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
self.health_check_response_counter -= 1
return None
return response

def is_health_check_response(self, response):
"""
Check if the response is a health check response.
If there are no subscriptions redis responds to PING command with a
bulk response, instead of a multi-bulk with "pong" and the response.
"""
return response in [
self.health_check_response, # If there was a subscription
self.health_check_response_b, # If there wasn't
]

def check_health(self):
conn = self.connection
if conn is None:
Expand All @@ -1414,6 +1447,7 @@ def check_health(self):

if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
self.health_check_response_counter += 1

def _normalize_keys(self, data):
"""
Expand Down Expand Up @@ -1443,6 +1477,11 @@ def psubscribe(self, *args, **kwargs):
# for the reconnection.
new_patterns = self._normalize_keys(new_patterns)
self.patterns.update(new_patterns)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be done at the end of clean_health_check_response?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to put it there because the following scenario is possible:

  1. p.subscribe("foo")
  2. health check is performed
  3. p.unsubscribe("foo")
  4. a health check response still hasn't received
  5. p.unsuscribe("foo")
  6. clean_health_check_response is being called by the unsubscribe command, the health check response hasn't arrived yet and it exists the loop due to ttl runs-out
  7. the health check response only now received
  8. p.subscribe() is being called - self.subscribed is still False so a health check will be performed and we should clean the existing health check response before we continue.
  9. If we add 'clean_health_check_response=0' at the end of clean_health_check, we will clean the counter in step 6, so we won't be able to clean the socket from the response on step 8.

self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val

Expand Down Expand Up @@ -1477,6 +1516,11 @@ def subscribe(self, *args, **kwargs):
# for the reconnection.
new_channels = self._normalize_keys(new_channels)
self.channels.update(new_channels)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val

Expand Down Expand Up @@ -1508,6 +1552,20 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
before returning. Timeout should be specified as a floating point
number.
"""
if not self.subscribed:
# Wait for subscription
start_time = time.time()
if self.subscribed_event.wait(timeout) is True:
barshaul marked this conversation as resolved.
Show resolved Hide resolved
# The connection was subscribed during the timeout time frame.
# The timeout should be adjusted based on the time spent
# waiting for the subscription
time_spent = time.time() - start_time
timeout = max(0.0, timeout - time_spent)
else:
# The connection isn't subscribed to any channels or patterns,
# so no messages are available
return None

response = self.parse_response(block=False, timeout=timeout)
barshaul marked this conversation as resolved.
Show resolved Hide resolved
if response:
return self.handle_message(response, ignore_subscribe_messages)
Expand Down Expand Up @@ -1561,6 +1619,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
if not self.channels and not self.patterns:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()

if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
Expand Down
43 changes: 34 additions & 9 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import time
from unittest import mock
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -348,15 +349,6 @@ def test_unicode_pattern_message_handler(self, r):
"pmessage", channel, "test message", pattern=pattern
)

def test_get_message_without_subscribe(self, r):
p = r.pubsub()
with pytest.raises(RuntimeError) as info:
p.get_message()
expect = (
"connection not set: " "did you forget to call subscribe() or psubscribe()?"
)
assert expect in info.exconly()


class TestPubSubAutoDecoding:
"These tests only validate that we get unicode values back"
Expand Down Expand Up @@ -549,6 +541,39 @@ def test_get_message_with_timeout_returns_none(self, r):
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
assert p.get_message(timeout=0.01) is None

def test_get_message_not_subscribed_return_none(self, r):
p = r.pubsub()
assert p.subscribed is False
assert p.get_message() is None
assert p.get_message(timeout=0.1) is None
with patch.object(threading.Event, "wait") as mock:
mock.return_value = False
assert p.get_message(timeout=0.01) is None
assert mock.called

def test_get_message_subscribe_during_waiting(self, r):
p = r.pubsub()

def poll(ps, expected_res):
assert ps.get_message() is None
message = ps.get_message(timeout=1)
assert message == expected_res

subscribe_response = make_message("subscribe", "foo", 1)
poller = threading.Thread(target=poll, args=(p, subscribe_response))
poller.start()
time.sleep(0.2)
p.subscribe("foo")
poller.join()

def test_get_message_wait_for_subscription_not_being_called(self, r):
p = r.pubsub()
p.subscribe("foo")
with patch.object(threading.Event, "wait") as mock:
assert p.subscribed is True
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
assert mock.called is False


class TestPubSubWorkerThread:
@pytest.mark.skipif(
Expand Down