Skip to content

Commit

Permalink
Merge pull request #298 from jamesls/deferred-socket-close
Browse files Browse the repository at this point in the history
Defer the cleanup in FakeSocket.close
  • Loading branch information
bmerry authored May 26, 2021
2 parents 5f410c4 + 34c8edf commit dcfe07e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
44 changes: 36 additions & 8 deletions fakeredis/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Database(MutableMapping):
def __init__(self, lock, *args, **kwargs):
self._dict = dict(*args, **kwargs)
self.time = 0.0
self._watches = defaultdict(set) # key to set of connections
self._watches = defaultdict(weakref.WeakSet) # key to set of connections
self.condition = threading.Condition(lock)
self._change_callbacks = set()

Expand Down Expand Up @@ -640,6 +640,8 @@ def __init__(self):
self.psubscribers = defaultdict(weakref.WeakSet)
self.lastsave = int(time.time())
self.connected = True
# List of weakrefs to sockets that are being closed lazily
self.closed_sockets = []


class FakeSocket:
Expand All @@ -664,7 +666,12 @@ def __init__(self, server):
self._parser.send(None)

def put_response(self, msg):
self.responses.put(msg)
# redis.Connection.__del__ might call self.close at any time, which
# will set self.responses to None. We assume this will happen
# atomically, and the code below then protects us against this.
responses = self.responses
if responses:
responses.put(msg)

def pause(self):
self._paused = True
Expand All @@ -682,13 +689,24 @@ def fileno(self):
# `FakeSelector` before it is ever used.
return 0

def _cleanup(self, server):
"""Remove all the references to `self` from `server`.
This is called with the server lock held, but it may be some time after
self.close.
"""
for subs in server.subscribers.values():
subs.discard(self)
for subs in server.psubscribers.values():
subs.discard(self)
self._clear_watches()

def close(self):
with self._server.lock:
for subs in self._server.subscribers.values():
subs.discard(self)
for subs in self._server.psubscribers.values():
subs.discard(self)
self._clear_watches()
# Mark ourselves for cleanup. This might be called from
# redis.Connection.__del__, which the garbage collection could call
# at any time, and hence we can't safely take the server lock.
# We rely on list.append being atomic.
self._server.closed_sockets.append(weakref.ref(self))
self._server = None
self._db = None
self.responses = None
Expand Down Expand Up @@ -819,6 +837,16 @@ def _process_command(self, fields):
func, func_name = self._name_to_func(fields[0])
sig = func._fakeredis_sig
with self._server.lock:
# Clean out old connections
while True:
try:
weak_sock = self._server.closed_sockets.pop()
except IndexError:
break
else:
sock = weak_sock()
if sock:
sock._cleanup(self._server)
now = time.time()
for db in self._server.dbs.values():
db.time = now
Expand Down
28 changes: 28 additions & 0 deletions test/test_fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4694,6 +4694,34 @@ def test_unlink(r):
assert r.get('foo') is None


@pytest.mark.skipif(REDIS_VERSION < "3.4", reason="Test requires redis-py 3.4+")
@pytest.mark.fake
def test_socket_cleanup_pubsub(fake_server):
r1 = fakeredis.FakeStrictRedis(server=fake_server)
r2 = fakeredis.FakeStrictRedis(server=fake_server)
ps = r1.pubsub()
with ps:
ps.subscribe('test')
ps.psubscribe('test*')
r2.publish('test', 'foo')


@pytest.mark.fake
def test_socket_cleanup_watch(fake_server):
r1 = fakeredis.FakeStrictRedis(server=fake_server)
r2 = fakeredis.FakeStrictRedis(server=fake_server)
pipeline = r1.pipeline(transaction=False)
# This needs some poking into redis-py internals to ensure that we reach
# FakeSocket._cleanup. We need to close the socket while there is still
# a watch in place, but not allow it to be garbage collected (hence we
# set 'sock' even though it is unused).
with pipeline:
pipeline.watch('test')
sock = pipeline.connection._sock # noqa: F841
pipeline.connection.disconnect()
r2.set('test', 'foo')


@redis2_only
@pytest.mark.parametrize(
'create_redis',
Expand Down

0 comments on commit dcfe07e

Please sign in to comment.