From 375c1ef864c35adcc6c3924e76bbd47f5cc28d0e Mon Sep 17 00:00:00 2001 From: Daniel M Date: Sat, 19 Oct 2024 11:46:47 -0400 Subject: [PATCH] support for server type specific commands (#340) --- docs/about/changelog.md | 7 +++- docs/index.md | 15 ++++--- docs/valkey-support.md | 19 +++++++++ fakeredis/_basefakesocket.py | 4 ++ fakeredis/_commands.py | 2 + fakeredis/_fakesocket.py | 42 ++++++++----------- fakeredis/_server.py | 6 ++- fakeredis/_tcp_server.py | 3 +- fakeredis/commands_mixins/__init__.py | 41 ++++++++++++++++++ fakeredis/commands_mixins/connection_mixin.py | 2 +- fakeredis/commands_mixins/generic_mixin.py | 6 +-- .../server_specific_commands/__init__.py | 5 +++ .../dragonfly_mixin.py | 20 +++++++++ pyproject.toml | 2 +- test/test_mixins/test_server_commands.py | 10 +++-- 15 files changed, 140 insertions(+), 44 deletions(-) create mode 100644 fakeredis/server_specific_commands/__init__.py create mode 100644 fakeredis/server_specific_commands/dragonfly_mixin.py diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 11b87908..e95acc41 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -7,7 +7,12 @@ tags: toc_depth: 2 --- -## v2.25.2 +## v2.26.0 + +### 🚀 Features + +- Support for server-type specific commands #340 +- Support for Dragonfly `SADDEX` command #340 ### 🐛 Bug Fixes diff --git a/docs/index.md b/docs/index.md index 4fb46470..4e909ba9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ ---- +from test.test_hypothesis import server_typefrom test.test_hypothesis import server_type--- toc: - toc_depth: 3 +toc_depth: 3 --- fakeredis: A python implementation of redis server @@ -46,7 +46,7 @@ from threading import Thread from fakeredis import TcpFakeServer server_address = ("127.0.0.1", 6379) -server = TcpFakeServer(server_address) +server = TcpFakeServer(server_address, server_type="redis") t = Thread(target=server.serve_forever, daemon=True) t.start() @@ -73,14 +73,15 @@ def redis_client(request): ### General usage -FakeRedis can imitate Redis server version 6.x or 7.x. Version 7 is used by default. +FakeRedis can imitate Redis server version 6.x or 7.x, [Valkey server](./valkey-support), +and [dragonfly server][dragonfly]. Redis version 7 is used by default. The intent is for fakeredis to act as though you're talking to a real redis server. It does this by storing the state internally. For example: ```pycon >>> import fakeredis ->>> r = fakeredis.FakeStrictRedis(version=6) +>>> r = fakeredis.FakeStrictRedis(server_type="redis") >>> r.set('foo', 'bar') True >>> r.get('foo') @@ -391,4 +392,6 @@ You can support this project by becoming a sponsor using [this link][2]. [8]:https://github.com/jazzband/django-redis -[9]:https://docs.djangoproject.com/en/4.1/topics/testing/tools/#django.test.override_settings \ No newline at end of file +[9]:https://docs.djangoproject.com/en/4.1/topics/testing/tools/#django.test.override_settings + +[dragonfly]:https://www.dragonflydb.io/ diff --git a/docs/valkey-support.md b/docs/valkey-support.md index fd7de8e6..00169a3d 100644 --- a/docs/valkey-support.md +++ b/docs/valkey-support.md @@ -18,4 +18,23 @@ valkey.set("key", "value") print(valkey.get("key")) ``` +Alternatively, you can start a thread with a Fake Valkey server. + +```python +from threading import Thread +from fakeredis import TcpFakeServer + +server_address = ("127.0.0.1", 6379) +server = TcpFakeServer(server_address, server_type="valkey") +t = Thread(target=server.serve_forever, daemon=True) +t.start() + +import valkey + +r = valkey.Valkey(host=server_address[0], port=server_address[1]) +r.set("foo", "bar") +assert r.get("foo") == b"bar" + +``` + [1]: https://github.com/valkey-io/valkey diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index c0ef4f55..7d8e720c 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -251,6 +251,10 @@ def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable[[Any], Any]], clean_name = cmd_name.replace("\r", " ").replace("\n", " ") raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name)) sig = SUPPORTED_COMMANDS[cmd_name] + if self._server.server_type not in sig.server_types: + # redis remaps \r or \n in an error to ' ' to make it legal protocol + clean_name = cmd_name.replace("\r", " ").replace("\n", " ") + raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name)) func = getattr(self, sig.func_name, None) return func, sig diff --git a/fakeredis/_commands.py b/fakeredis/_commands.py index aafbf176..f99b90bb 100644 --- a/fakeredis/_commands.py +++ b/fakeredis/_commands.py @@ -425,6 +425,7 @@ def __init__( repeat: Tuple[Type[Union[RedisType, bytes]]] = (), # type:ignore args: Tuple[str] = (), # type:ignore flags: str = "", + server_types: Tuple[str] = ("redis", "valkey", "dragonfly"), # supported server types: redis, dragonfly, valkey ): self.name = name self.func_name = func_name @@ -432,6 +433,7 @@ def __init__( self.repeat = repeat self.flags = set(flags) self.command_args = args + self.server_types: Set[str] = set(server_types) def check_arity(self, args: Sequence[Any], version: Tuple[int]) -> None: if len(args) == len(self.fixed): diff --git a/fakeredis/_fakesocket.py b/fakeredis/_fakesocket.py index 0525a499..47296d7f 100644 --- a/fakeredis/_fakesocket.py +++ b/fakeredis/_fakesocket.py @@ -1,5 +1,20 @@ -from typing import Optional, Set, Any +from typing import Optional, Set +from fakeredis.commands_mixins import ( + BitmapCommandsMixin, + ConnectionCommandsMixin, + GenericCommandsMixin, + GeoCommandsMixin, + HashCommandsMixin, + ListCommandsMixin, + PubSubCommandsMixin, + ScriptingCommandsMixin, + ServerCommandsMixin, + StringCommandsMixin, + TransactionsCommandsMixin, + SetCommandsMixin, + StreamsCommandsMixin, +) from fakeredis.stack import ( JSONCommandsMixin, BFCommandsMixin, @@ -11,30 +26,8 @@ ) from ._basefakesocket import BaseFakeSocket from ._server import FakeServer -from .commands_mixins.bitmap_mixin import BitmapCommandsMixin -from .commands_mixins.connection_mixin import ConnectionCommandsMixin -from .commands_mixins.generic_mixin import GenericCommandsMixin -from .commands_mixins.geo_mixin import GeoCommandsMixin -from .commands_mixins.hash_mixin import HashCommandsMixin -from .commands_mixins.list_mixin import ListCommandsMixin -from .commands_mixins.pubsub_mixin import PubSubCommandsMixin - -try: - from .commands_mixins.scripting_mixin import ScriptingCommandsMixin -except ImportError: - - class ScriptingCommandsMixin: # type: ignore # noqa: E303 - def __init__(self, *args: Any, **kwargs: Any) -> None: - kwargs.pop("lua_modules", None) - super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore - - -from .commands_mixins.server_mixin import ServerCommandsMixin -from .commands_mixins.set_mixin import SetCommandsMixin from .commands_mixins.sortedset_mixin import SortedSetCommandsMixin -from .commands_mixins.streams_mixin import StreamsCommandsMixin -from .commands_mixins.string_mixin import StringCommandsMixin -from .commands_mixins.transactions_mixin import TransactionsCommandsMixin +from .server_specific_commands import DragonflyCommandsMixin class FakeSocket( @@ -60,6 +53,7 @@ class FakeSocket( TopkCommandsMixin, TDigestCommandsMixin, TimeSeriesCommandsMixin, + DragonflyCommandsMixin, ): def __init__( self, diff --git a/fakeredis/_server.py b/fakeredis/_server.py index e856666d..97d200f9 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -3,7 +3,7 @@ import time import weakref from collections import defaultdict -from typing import Dict, Tuple, Any, List, Optional, Union +from typing import Dict, Tuple, Any, List, Optional, Union, Literal from fakeredis._helpers import Database, FakeSelector @@ -11,6 +11,8 @@ VersionType = Union[Tuple[int, ...], int, str] +ServerType = Literal["redis", "dragonfly", "valkey"] + def _create_version(v: VersionType) -> Tuple[int, ...]: if isinstance(v, tuple): @@ -26,7 +28,7 @@ def _create_version(v: VersionType) -> Tuple[int, ...]: class FakeServer: _servers_map: Dict[str, "FakeServer"] = dict() - def __init__(self, version: VersionType = (7,), server_type: str = "redis") -> None: + def __init__(self, version: VersionType = (7,), server_type: ServerType = "redis") -> None: self.lock = threading.Lock() self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock)) # Maps channel/pattern to a weak set of sockets diff --git a/fakeredis/_tcp_server.py b/fakeredis/_tcp_server.py index 3743b81d..4fe990ae 100644 --- a/fakeredis/_tcp_server.py +++ b/fakeredis/_tcp_server.py @@ -6,6 +6,7 @@ from fakeredis import FakeRedis from fakeredis import FakeServer +from fakeredis._server import ServerType LOGGER = logging.getLogger("fakeredis") @@ -113,7 +114,7 @@ def __init__( self, server_address: Tuple[str | bytes | bytearray, int], bind_and_activate: bool = True, - server_type: str = "redis", + server_type: ServerType = "redis", server_version: Tuple[int, ...] = (7, 4), ): super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate) diff --git a/fakeredis/commands_mixins/__init__.py b/fakeredis/commands_mixins/__init__.py index e69de29b..440c9643 100644 --- a/fakeredis/commands_mixins/__init__.py +++ b/fakeredis/commands_mixins/__init__.py @@ -0,0 +1,41 @@ +from typing import Any + +from .bitmap_mixin import BitmapCommandsMixin +from .connection_mixin import ConnectionCommandsMixin +from .generic_mixin import GenericCommandsMixin +from .geo_mixin import GeoCommandsMixin +from .hash_mixin import HashCommandsMixin +from .list_mixin import ListCommandsMixin +from .pubsub_mixin import PubSubCommandsMixin +from .server_mixin import ServerCommandsMixin +from .set_mixin import SetCommandsMixin +from .streams_mixin import StreamsCommandsMixin +from .string_mixin import StringCommandsMixin + +try: + from .scripting_mixin import ScriptingCommandsMixin +except ImportError: + + class ScriptingCommandsMixin: # type: ignore # noqa: E303 + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs.pop("lua_modules", None) + super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore + + +from .transactions_mixin import TransactionsCommandsMixin + +__all__ = [ + "BitmapCommandsMixin", + "ConnectionCommandsMixin", + "GenericCommandsMixin", + "GeoCommandsMixin", + "HashCommandsMixin", + "ListCommandsMixin", + "PubSubCommandsMixin", + "ScriptingCommandsMixin", + "TransactionsCommandsMixin", + "ServerCommandsMixin", + "SetCommandsMixin", + "StreamsCommandsMixin", + "StringCommandsMixin", +] diff --git a/fakeredis/commands_mixins/connection_mixin.py b/fakeredis/commands_mixins/connection_mixin.py index 9506e931..e7056e2b 100644 --- a/fakeredis/commands_mixins/connection_mixin.py +++ b/fakeredis/commands_mixins/connection_mixin.py @@ -32,7 +32,7 @@ def ping(self, *args: bytes) -> Union[List[bytes], bytes, SimpleString]: else: return args[0] if args else PONG - @command((DbIndex,)) + @command(name="SELECT", fixed=(DbIndex,)) def select(self, index: DbIndex) -> SimpleString: self._db = self._server.dbs[index] self._db_num = index # type: ignore diff --git a/fakeredis/commands_mixins/generic_mixin.py b/fakeredis/commands_mixins/generic_mixin.py index b0b761bb..7af9f3e0 100644 --- a/fakeredis/commands_mixins/generic_mixin.py +++ b/fakeredis/commands_mixins/generic_mixin.py @@ -110,11 +110,7 @@ def exists(self, *keys): ret += 1 return ret - @command( - name="expire", - fixed=(Key(), Int), - repeat=(bytes,), - ) + @command(name="EXPIRE", fixed=(Key(), Int), repeat=(bytes,)) def expire(self, key: CommandItem, seconds: int, *args: bytes) -> int: res = self._expireat(key, self._db.time + seconds, *args) return res diff --git a/fakeredis/server_specific_commands/__init__.py b/fakeredis/server_specific_commands/__init__.py new file mode 100644 index 00000000..8df4185c --- /dev/null +++ b/fakeredis/server_specific_commands/__init__.py @@ -0,0 +1,5 @@ +from fakeredis.server_specific_commands.dragonfly_mixin import DragonflyCommandsMixin + +__all__ = [ + "DragonflyCommandsMixin", +] diff --git a/fakeredis/server_specific_commands/dragonfly_mixin.py b/fakeredis/server_specific_commands/dragonfly_mixin.py new file mode 100644 index 00000000..93807488 --- /dev/null +++ b/fakeredis/server_specific_commands/dragonfly_mixin.py @@ -0,0 +1,20 @@ +from typing import Callable + +from fakeredis._commands import command, Key, Int, CommandItem +from fakeredis._helpers import Database + + +class DragonflyCommandsMixin(object): + _expireat: Callable[[CommandItem, int], int] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._db: Database + + @command(name="SADDEX", fixed=(Key(set), Int, bytes), repeat=(bytes,), server_types=("dragonfly",)) + def saddex(self, key: CommandItem, seconds: int, *members: bytes) -> int: + old_size = len(key.value) + key.value.update(members) + key.updated() + self._expireat(key, self._db.time + seconds) + return len(key.value) - old_size diff --git a/pyproject.toml b/pyproject.toml index d6aff562..14e72764 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ { include = "fakeredis" }, { include = "LICENSE", to = "fakeredis" }, ] -version = "2.25.2" +version = "2.26.0" description = "Python implementation of redis API, can be used for testing purposes." readme = "README.md" keywords = ["redis", "RedisJson", "RedisBloom", "tests", "redis-stack"] diff --git a/test/test_mixins/test_server_commands.py b/test/test_mixins/test_server_commands.py index 367684e9..c3cd4150 100644 --- a/test/test_mixins/test_server_commands.py +++ b/test/test_mixins/test_server_commands.py @@ -50,13 +50,17 @@ def test_lastsave(r: redis.Redis): @fake_only def test_command(r: redis.Redis): commands_dict = r.command() - one_word_commands = {cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd} - assert one_word_commands - set(commands_dict.keys()) == set() + one_word_commands = {cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd and SUPPORTED_COMMANDS[cmd].server_types} + server_unsupported_commands = one_word_commands - set(commands_dict.keys()) + for command in server_unsupported_commands: + assert "redis" not in SUPPORTED_COMMANDS[command].server_types @fake_only def test_command_count(r: redis.Redis): - assert r.command_count() >= len([cmd for cmd in SUPPORTED_COMMANDS if " " not in cmd]) + assert r.command_count() >= len( + [cmd for (cmd, cmd_info) in SUPPORTED_COMMANDS.items() if " " not in cmd and "redis" in cmd_info.server_types] + ) @pytest.mark.unsupported_server_types("dragonfly")