Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix some error cases in the caching layer. (#5749)
Browse files Browse the repository at this point in the history
  • Loading branch information
anoadragon453 committed Feb 19, 2020
2 parents f7bf143 + 618bd1e commit c544308
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 35 deletions.
1 change: 1 addition & 0 deletions changelog.d/5749.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix some error cases in the caching layer.
74 changes: 42 additions & 32 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import threading
from collections import namedtuple

import six
from six import itervalues, string_types
from six import itervalues

from prometheus_client import Gauge

Expand All @@ -32,7 +31,6 @@
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii

from . import register_cache

Expand Down Expand Up @@ -124,7 +122,7 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either a Deferred or the raw result
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
Expand All @@ -148,40 +146,63 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
return default

def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")

callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks)
observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)

existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()

self._pending_deferred_cache[key] = entry

def shuffle(result):
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True

# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

return False

def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
return result

entry.deferred.addCallback(shuffle)
def eb(_fail):
compare_and_pop()
entry.invalidate()

# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable

def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
Expand Down Expand Up @@ -414,20 +435,10 @@ def onErr(f):

ret.addErrback(onErr)

# If our cache_key is a string on py2, try to convert to ascii
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)

result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()

if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer)
else:
return observer
return make_deferred_yieldable(observer)

if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
Expand Down Expand Up @@ -543,16 +554,15 @@ def arg_to_cache_key(arg):
missing.add(arg)

if missing:
# we need an observable deferred for each entry in the list,
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
cache.set(key, deferred, callback=invalidate_callback)

def complete_all(res):
# the wrapped function has completed. It returns a
Expand Down
90 changes: 87 additions & 3 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached

from tests import unittest

Expand Down Expand Up @@ -55,12 +56,15 @@ def record_callback(idx):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))

# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())

# let one of the lookups complete
d2.callback("result2")

# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")

# now do the invalidation
Expand Down Expand Up @@ -146,6 +150,28 @@ def fn(self, arg1, arg2):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""

class Cls(object):
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")

obj = Cls()

# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)

# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)

# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)

def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
Expand Down Expand Up @@ -222,6 +248,9 @@ def do_lookup():

self.assertEqual(LoggingContext.current_context(), c1)

# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)

obj = Cls()

# set off a deferred which will do a cache lookup
Expand Down Expand Up @@ -268,6 +297,61 @@ def fn(self, arg1, arg2=2, arg3=3):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

def test_cache_iterable(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached(iterable=True)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)

obj = Cls()

obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()

# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)

r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_not_called()

def test_cache_iterable_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""

class Cls(object):
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")

obj = Cls()

# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)

# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)

# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)


class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
Expand Down

0 comments on commit c544308

Please sign in to comment.