Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) Implement hash expiration functions #328

Merged
merged 22 commits into from
Sep 22, 2024
59 changes: 57 additions & 2 deletions fakeredis/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import re
import sys
import time
from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set
from typing import Iterable, Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError, Database
from ._helpers import null_terminate, SimpleError, Database, HexpireResult

MAX_STRING_SIZE = 512 * 1024 * 1024
SUPPORTED_COMMANDS: Dict[str, "Signature"] = dict() # Dictionary of supported commands name => Signature
Expand Down Expand Up @@ -110,6 +110,61 @@ def __bool__(self) -> bool:
class Hash(dict): # type:ignore
DECODE_ERROR = msgs.INVALID_HASH_MSG
redis_type = b"hash"
_expirations: Dict[bytes, int | float]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._expirations = {}

def _prune_key_with_expiration(self, key: bytes) -> None:
self.pop(key, None)
self._expirations.pop(key, None)

def _is_expired(self, key: bytes) -> bool:
if self._expirations.get(key, 0) < time.time():
self._prune_key_with_expiration(key)
return True
return False

def _set_expiration(self, key: bytes, when: Union[int, float]) -> HexpireResult:
now = time.time()
if isinstance(when, int):
now = int(now)
if when <= now:
self._prune_key_with_expiration(key)
return HexpireResult.EXPIRED_IMMEDIATELY
self._expirations[key] = when
return HexpireResult.SUCCESS

def _clear_expiration(self, key: bytes) -> bool:
result = self._expirations.pop(key, None)
return result is not None

def _get_expiration(self, key: bytes) -> Union[None, int, float]:
if not self._is_expired(key):
return self._expirations.get(key, None)
return None

def __get__(self, key: bytes) -> Any:
self._is_expired(key)
return super().__get__(key)

def __contains__(self, key: bytes) -> bool:
self._is_expired(key)
return super().__contains__(key)

def __set__(self, key: bytes, value: Any) -> None:
self._expirations.pop(key, None)
super().__set__(key, value)

def keys(self) -> Iterable[bytes]:
return [k for k in super().keys() if not self._is_expired(k)]

def values(self) -> Iterable[Any]:
return [v for k, v in self.items()]

def items(self) -> Iterable[Tuple[bytes, Any]]:
return [(k, v) for k, v in super().items() if not self._is_expired(k)]


class RedisType:
Expand Down
7 changes: 7 additions & 0 deletions fakeredis/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,10 @@ def check_can_read(self, timeout: Optional[float]) -> bool:
@staticmethod
def check_is_ready_for_command(_: Any) -> bool:
return True


class HexpireResult(enum.IntEnum):
NO_SUCH_KEY = -2
CONDITION_UNMET = 0
SUCCESS = 1
EXPIRED_IMMEDIATELY = 2
139 changes: 138 additions & 1 deletion fakeredis/commands_mixins/hash_mixin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import datetime
import itertools
import math
import random
from typing import Callable, List, Tuple, Any, Optional
import time
from typing import Callable, List, Tuple, Any, Optional, Union

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Hash, Int, Float, CommandItem
from fakeredis._helpers import SimpleError, OK, casematch, SimpleString
from fakeredis._helpers import HexpireResult


class HashCommandsMixin:
Expand Down Expand Up @@ -134,3 +138,136 @@ def hrandfield(self, key: CommandItem, *args: bytes) -> Optional[List[bytes]]:
else:
res = [t[0] for t in res]
return res

def _set_key_expiration(
self,
key: CommandItem,
field: bytes,
expiration: Union[datetime.datetime, datetime.timedelta, int, float],
include_ms: bool,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> HexpireResult:
hash_val: Hash = key.value
if field not in hash_val:
return HexpireResult.NO_SUCH_KEY
current_expiration = hash_val._get_expiration(field)
if isinstance(expiration, datetime.datetime):
final_expiration = expiration.timestamp()
elif isinstance(expiration, datetime.timedelta):
final_expiration = time.time() + expiration.total_seconds()
else:
final_expiration = time.time() + expiration
if include_ms:
final_expiration = float(final_expiration)
else:
final_expiration = int(final_expiration)
if (
(nx and current_expiration is not None)
or (xx and current_expiration is None)
or (gt and final_expiration <= current_expiration)
or (lt and final_expiration >= current_expiration)
):
return HexpireResult.CONDITION_UNMET
return hash_val._set_expiration(field, final_expiration)

@command(name="HEXPIRE", fixed=(Key, Union[int, datetime.timedelta]), repeat=(bytes,))
def hexpire(
self,
key: CommandItem,
seconds: Union[int, datetime.timedelta],
*fields: bytes,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> list[HexpireResult]:
return [self._set_key_expiration(key, field, seconds, False, nx, xx, gt, lt) for field in fields]

@command(name="HPEXPIRE", fixed=(Key, Union[float, datetime.timedelta]), repeat=(bytes,))
def hpexpire(
self,
key: CommandItem,
milliseconds: float,
*fields: bytes,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> list[HexpireResult]:
return [self._set_key_expiration(key, field, milliseconds, True, nx, xx, gt, lt) for field in fields]

@command(name="HEXPIREAT", fixed=(Key, Union[int, datetime.datetime]), repeat=(bytes,))
def hexpireat(
self,
key: CommandItem,
unix_time_seconds: Union[int, datetime.datetime],
*fields: bytes,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> list[HexpireResult]:
return [self._set_key_expiration(key, field, unix_time_seconds, False, nx, xx, gt, lt) for field in fields]

@command(name="HPEXPIREAT", fixed=(Key, Union[float, datetime.datetime]), repeat=(bytes,))
def hpexpireat(
self,
key: CommandItem,
unix_time_milliseconds: float,
*fields: bytes,
nx: bool = False,
cunla marked this conversation as resolved.
Show resolved Hide resolved
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> list[HexpireResult]:
return [self._set_key_expiration(key, field, unix_time_milliseconds, True, nx, xx, gt, lt) for field in fields]

@command(name="HPERSIST", fixed=(Key,), repeat=(bytes,))
def hpersist(self, key: CommandItem, *fields: bytes) -> list[int]:
hash_val: Hash = key.value
return [
-2 if field not in hash_val._expirations else (1 if hash_val._clear_expiration(field) else -1)
for field in fields
]

@command(name="HEXPIRETIME", fixed=(Key,), repeat=(bytes,))
def hexpiretime(self, key: CommandItem, *fields: bytes) -> list[int]:
hash_val: Hash = key.value
return [
-2 if field not in hash_val._expirations else int(hash_val._get_expiration(field) or -1) for field in fields
]

@command(name="HPEXPIRETIME", fixed=(Key,), repeat=(bytes,))
def hpexpiretime(self, key: CommandItem, *fields: bytes) -> list[float]:
hash_val: Hash = key.value
return [
-2 if field not in hash_val._expirations else float(hash_val._get_expiration(field) or -1)
for field in fields
]

@command(name="HTTL", fixed=(Key,), repeat=(bytes,))
def httl(self, key: CommandItem, *fields: bytes) -> list[int]:
cunla marked this conversation as resolved.
Show resolved Hide resolved
hash_val: Hash = key.value
return [
(
-2
if field not in hash_val._expirations
else int(hash_val._get_expiration(field) - time.time() if hash_val._get_expiration(field) else -1)
)
for field in fields
]

@command(name="HPTTL", fixed=(Key,), repeat=(bytes,))
def hpttl(self, key: CommandItem, *fields: bytes) -> list[float]:
hash_val: Hash = key.value
return [
(
-2
if field not in hash_val._expirations
else float(hash_val._get_expiration(field) - time.time() if hash_val._get_expiration(field) else -1)
)
for field in fields
]
58 changes: 58 additions & 0 deletions test/test_mixins/test_hash_commands.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import datetime
import time
from typing import Union
from unittest.mock import patch

import pytest
import redis
import redis.client

from test import testtools
from fakeredis._helpers import HexpireResult


def test_hstrlen_missing(r: redis.Redis):
Expand Down Expand Up @@ -312,3 +318,55 @@ def test_hrandfield(r: redis.Redis):

with pytest.raises(redis.ResponseError):
testtools.raw_command(r, "HRANDFIELD", "key", 3, "WITHVALUES", 3)


BASE_TIME = 1000.0


@pytest.mark.parametrize(
"expiration,preset_expiration,nx,xx,gt,lt,expected_result",
[
# No flags
(BASE_TIME + 100, None, False, False, False, False, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), None, False, False, False, False, HexpireResult.SUCCESS),
(BASE_TIME + 100, BASE_TIME + 50, False, False, False, False, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), BASE_TIME + 50, False, False, False, False, HexpireResult.SUCCESS),
# NX
(BASE_TIME + 100, None, True, False, False, False, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), None, True, False, False, False, HexpireResult.SUCCESS),
(BASE_TIME + 100, BASE_TIME + 50, True, False, False, False, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), BASE_TIME + 50, True, False, False, False, HexpireResult.CONDITION_UNMET),
# XX
(BASE_TIME + 100, None, False, True, False, False, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), None, False, True, False, False, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 50, False, True, False, False, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), BASE_TIME + 50, False, True, False, False, HexpireResult.SUCCESS),
# GT
(BASE_TIME + 100, None, False, False, True, False, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), None, False, False, True, False, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 50, False, False, True, False, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), BASE_TIME + 50, False, False, True, False, HexpireResult.SUCCESS),
(BASE_TIME + 100, BASE_TIME + 100, False, False, True, False, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), BASE_TIME + 100, False, False, True, False, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 200, False, False, True, False, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), BASE_TIME + 200, False, False, True, False, HexpireResult.CONDITION_UNMET),
# LT
(BASE_TIME + 100, None, False, False, False, True, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), None, False, False, False, True, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 50, False, False, False, True, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), BASE_TIME + 50, False, False, False, True, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 100, False, False, False, True, HexpireResult.CONDITION_UNMET),
(datetime.timedelta(seconds=100), BASE_TIME + 100, False, False, False, True, HexpireResult.CONDITION_UNMET),
(BASE_TIME + 100, BASE_TIME + 200, False, False, False, True, HexpireResult.SUCCESS),
(datetime.timedelta(seconds=100), BASE_TIME + 200, False, False, False, True, HexpireResult.SUCCESS),
]
)
@patch.object(time, "time", BASE_TIME)
def test_hexpire(r: redis.Redis, current_time: float, expiration: Union[int, datetime.timedelta], preset_expiration: Union[float, None], nx: bool, xx: bool, gt: bool, lt: bool, expected_result: int) -> None:
cunla marked this conversation as resolved.
Show resolved Hide resolved
key = "test_hash_commands"
field = "test_hexpire"
r.hset(key, field)
if preset_expiration is not None:
r.hexpire(key, preset_expiration, field)
result = r.hexpire(key, expiration, field, nx=nx, xx=xx, gt=gt, lt=lt)
assert result == expected_result