Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add cache invalidation across workers to module API (#13667)
Browse files Browse the repository at this point in the history
Signed-off-by: Mathieu Velten <[email protected]>
  • Loading branch information
Mathieu Velten authored Sep 21, 2022
1 parent 16e1a9d commit 6bd8763
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/13667.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add cache invalidation across workers to module API.
4 changes: 2 additions & 2 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
"synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
Expand All @@ -38,7 +38,7 @@ def get_method_signature_hook(


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
"""Fixes the `CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
Expand Down
33 changes: 32 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze

if TYPE_CHECKING:
Expand Down Expand Up @@ -836,6 +836,37 @@ def run_db_interaction(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)

def register_cached_function(self, cached_func: CachedFunction) -> None:
    """Register a cached function that should be invalidated across workers.

    Invalidation local to a worker can be done directly using
    `cached_func.invalidate`, however invalidation that needs to go to other
    workers needs to call `invalidate_cache` on the module API instead.

    Args:
        cached_func: The cached function that will be registered to receive
            invalidation locally and from other workers.
    """
    # The function is registered under "<module>.<name>". `invalidate_cache`
    # sends a name built the same way, so the two must stay in sync.
    cache_name = f"{cached_func.__module__}.{cached_func.__name__}"
    self._store.register_external_cached_function(cache_name, cached_func)

async def invalidate_cache(
    self, cached_func: CachedFunction, keys: Tuple[Any, ...]
) -> None:
    """Invalidate a cache entry of a cached function across workers.

    The cached function needs to be registered on all workers first with
    `register_cached_function`.

    Args:
        cached_func: The cached function that needs an invalidation.
        keys: Keys of the entry to invalidate, usually matching the arguments
            of the cached function.
    """
    # Drop the entry on this process first, then tell the other workers.
    cached_func.invalidate(keys)

    # Must match the "<module>.<name>" form used by `register_cached_function`,
    # otherwise the receiving side cannot find the registered function.
    cache_name = f"{cached_func.__module__}.{cached_func.__name__}"
    await self._store.send_invalidation_to_replication(cache_name, keys)

async def complete_sso_login_async(
self,
registered_user_id: str,
Expand Down
23 changes: 18 additions & 5 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# limitations under the License.
import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union

from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -47,6 +48,8 @@ def __init__(
self.database_engine = database.engine
self.db_pool = database

self.external_cached_functions: Dict[str, CachedFunction] = {}

def process_replication_rows(
self,
stream_name: str,
Expand Down Expand Up @@ -95,7 +98,7 @@ def _invalidate_state_caches(

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
) -> None:
) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Expand All @@ -113,9 +116,12 @@ def _attempt_to_invalidate_cache(
try:
cache = getattr(self, cache_name)
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return
# Check if an externally defined module cache has been registered
cache = self.external_cached_functions.get(cache_name)
if not cache:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return False

if key is None:
cache.invalidate_all()
Expand All @@ -125,6 +131,13 @@ def _attempt_to_invalidate_cache(
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))

return True

def register_external_cached_function(
    self, cache_name: str, func: CachedFunction
) -> None:
    """Store a cached function under `cache_name` in `external_cached_functions`.

    `_attempt_to_invalidate_cache` falls back to this mapping when the store
    itself has no attribute named after the cache being invalidated.
    Registering the same name again replaces the earlier entry.

    Args:
        cache_name: Key under which the function is stored.
        func: The cached function to store.
    """
    self.external_cached_functions[cache_name] = func


def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Expand Down
20 changes: 14 additions & 6 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -269,17 +269,15 @@ async def invalidate_cache_and_stream(
return

cache_func.invalidate(keys)
await self.db_pool.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)

def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
Expand All @@ -293,7 +291,7 @@ def _invalidate_cache_and_stream(
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
Expand Down Expand Up @@ -334,6 +332,16 @@ def _invalidate_state_caches_and_stream(
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)

async def send_invalidation_to_replication(
    self, cache_name: str, keys: Optional[Collection[Any]]
) -> None:
    """Run `_send_invalidation_to_replication` inside a database interaction.

    This is the awaitable wrapper around the transaction-level helper, for
    callers that do not already hold a transaction.

    Args:
        cache_name: Name of the cache to invalidate; passed through unchanged.
        keys: Keys of the entry to invalidate; passed through unchanged.
            NOTE(review): `None` presumably means "invalidate the whole
            cache" -- confirm against `_send_invalidation_to_replication`.
    """
    await self.db_pool.runInteraction(
        "send_invalidation_to_replication",
        self._send_invalidation_to_replication,
        cache_name,
        keys,
    )

def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
Expand Down
14 changes: 7 additions & 7 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
F = TypeVar("F", bound=Callable[..., Any])


class _CachedFunction(Generic[F]):
class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
Expand Down Expand Up @@ -242,7 +242,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:

return ret2

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.name] = wrapped

Expand Down Expand Up @@ -363,7 +363,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:

return make_deferred_yieldable(ret)

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)

if self.num_args == 1:
assert not self.tree
Expand Down Expand Up @@ -572,7 +572,7 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
Expand All @@ -585,7 +585,7 @@ def cached(
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def cachedList(
Expand All @@ -594,7 +594,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
Expand Down Expand Up @@ -631,7 +631,7 @@ def batch_do_something(self, first_arg, second_args):
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def _get_cache_key_builder(
Expand Down
79 changes: 79 additions & 0 deletions tests/replication/test_module_cache_invalidation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import synapse
from synapse.module_api import cached

from tests.replication._base import BaseMultiWorkerStreamTestCase

logger = logging.getLogger(__name__)

FIRST_VALUE = "one"
SECOND_VALUE = "two"

KEY = "mykey"


class TestCache:
    """Minimal holder of a `@cached` function, used as a stand-in for a module."""

    # Value returned by `cached_function` whenever its body actually runs.
    # The test changes it to tell a cached result from a recomputed one.
    current_value = FIRST_VALUE

    @cached()
    async def cached_function(self, user_id: str) -> str:
        """Return `current_value`; `user_id` is not used by the body."""
        return self.current_value


class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
    """Checks that module API cache invalidation reaches another worker."""

    servlets = [
        synapse.rest.admin.register_servlets,
    ]

    def test_module_cache_full_invalidation(self) -> None:
        """Invalidate an entry on the main process and check the worker sees it."""
        # Each process gets its own TestCache instance, registered through
        # that process's module API.
        main_cache = TestCache()
        self.hs.get_module_api().register_cached_function(main_cache.cached_function)

        worker_hs = self.make_worker_hs("synapse.app.generic_worker")

        worker_cache = TestCache()
        worker_hs.get_module_api().register_cached_function(
            worker_cache.cached_function
        )

        # First call on each side populates the cache with FIRST_VALUE.
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        main_cache.current_value = SECOND_VALUE
        worker_cache.current_value = SECOND_VALUE
        # No invalidation yet, should return the cached value on both the main process and the worker
        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
        self.assertEqual(
            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

        # Invalidate the entry through the module API on the main process. This
        # should be replicated to the worker, which should then return the
        # updated value too.
        self.get_success(
            self.hs.get_module_api().invalidate_cache(
                main_cache.cached_function, (KEY,)
            )
        )

        self.assertEqual(
            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
        )
        self.assertEqual(
            SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
        )

0 comments on commit 6bd8763

Please sign in to comment.