Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract param handling #113

Merged
merged 20 commits into from
Jan 1, 2023
19 changes: 3 additions & 16 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import redis

from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import (
Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, key_value_type)
from ._helpers import (
Expand Down Expand Up @@ -266,22 +267,8 @@ def _scan(self, keys, cursor, *args):
returned exactly once.
"""
cursor = int(cursor)
pattern = None
_type = None
count = 10
if len(args) % 2 != 0:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
for i in range(0, len(args), 2):
if casematch(args[i], b'match'):
pattern = args[i + 1]
elif casematch(args[i], b'count'):
count = Int.decode(args[i + 1])
if count <= 0:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
elif casematch(args[i], b'type'):
_type = args[i + 1]
else:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
(pattern, _type, count), _ = extract_args(args, ('*match', '*type', '+count'))
count = 10 if count is None else count

if cursor >= len(keys):
return [0, []]
Expand Down
91 changes: 91 additions & 0 deletions fakeredis/_command_args_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Tuple, List, Dict

from . import _msgs as msgs
from ._commands import Int
from ._helpers import SimpleError, null_terminate


def _count_params(s: str):
res = 0
while s[res] == '+' or s[res] == '*':
res += 1
return res


def _encode_arg(s: str):
return s[_count_params(s):].encode()


def _default_value(s: str):
ind = _count_params(s)
if ind == 0:
return False
elif ind == 1:
return None
else:
return [None] * ind


def _parse_params(argument_name: str, ind: int, parse_following: int, actual_args: Tuple[bytes, ...]):
if parse_following == 0:
return True
if ind + parse_following >= len(actual_args):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
temp_res = []
for i in range(parse_following):
curr_arg = actual_args[ind + i + 1]
if argument_name[i] == '+':
curr_arg = Int.decode(curr_arg)
temp_res.append(curr_arg)

if len(temp_res) == 1:
return temp_res[0]
else:
return temp_res


def extract_args(
actual_args: Tuple[bytes, ...],
expected: Tuple[str, ...],
error_on_unexpected: bool = True,
left_from_first_unexpected: bool = True,
) -> Tuple[List, List]:
"""Parse argument values

Extract from actual arguments which arguments exist and their
numerical value.
An argument can have parameters:
- A numerical (Int) parameter is identified with +.
- A non-numerical parameter is identified with a *.
For example: '++limit' will translate as an argument with 2 int parameters.


>>> extract_args((b'nx', b'ex', b'324', b'xx',), ('nx', 'xx', '+ex', 'keepttl'))
[True, True, 324, False], None
"""

results: List = [_default_value(key) for key in expected]
left_args = []
args_info: Dict[bytes, int] = {
_encode_arg(k): (i, _count_params(k))
for (i, k) in enumerate(expected)
}
i = 0
while i < len(actual_args):
found = False
for key in args_info:
if null_terminate(actual_args[i]).lower() == key:
arg_position, parse_following = args_info[key]
results[arg_position] = _parse_params(expected[arg_position], i, parse_following, actual_args)
i += parse_following
found = True
break

if not found:
if error_on_unexpected:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if left_from_first_unexpected:
return results, actual_args[i:]
left_args.append(actual_args[i])
i += 1
return results, left_args
13 changes: 5 additions & 8 deletions fakeredis/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,14 @@ class NoResponse:
def null_terminate(s):
# Redis uses C functions on some strings, which means they stop at the
# first NULL.
if b'\0' in s:
return s[:s.find(b'\0')]
return s


def casenorm(s):
return null_terminate(s).lower()
ind = s.find(b'\0')
if ind > -1:
return s[:ind]
return s.lower()


def casematch(a, b):
return casenorm(a) == casenorm(b)
return null_terminate(a) == null_terminate(b)


def encode_command(s):
Expand Down
60 changes: 15 additions & 45 deletions fakeredis/commands_mixins/generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from random import random

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import (
command, Key, Int, DbIndex, BeforeAny, CommandItem, SortFloat,
delete_keys, key_value_type, )
Expand Down Expand Up @@ -163,14 +164,7 @@ def renamenx(self, key, newkey):

@command((Key(), Int, bytes), (bytes,))
def restore(self, key, ttl, value, *args):
replace = False
i = 0
while i < len(args):
if casematch(args[i], b'replace'):
replace = True
i += 1
else:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
(replace,), _ = extract_args(args, ('replace',))
if key and not replace:
raise SimpleError(msgs.RESTORE_KEY_EXISTS)
checksum, value = value[:20], value[20:]
Expand All @@ -192,49 +186,25 @@ def scan(self, cursor, *args):

@command((Key(),), (bytes,))
def sort(self, key, *args):
if key.value is not None and not isinstance(key.value, (set, list, ZSet)):
raise SimpleError(msgs.WRONGTYPE_MSG)
(asc, desc, alpha, store, sortby, (limit_start, limit_count)), args = extract_args(
args, ('asc', 'desc', 'alpha', '*store', '*by', '++limit'),
error_on_unexpected=False,
left_from_first_unexpected=False,
)
limit_start = limit_start or 0
limit_count = -1 if limit_count is None else limit_count
dontsort = (sortby is not None and b'*' not in sortby)

i = 0
desc = False
alpha = False
limit_start = 0
limit_count = -1
store = None
sortby = None
dontsort = False
get = []
if key.value is not None:
if not isinstance(key.value, (set, list, ZSet)):
raise SimpleError(msgs.WRONGTYPE_MSG)

while i < len(args):
arg = args[i]
if casematch(arg, b'asc'):
desc = False
elif casematch(arg, b'desc'):
desc = True
elif casematch(arg, b'alpha'):
alpha = True
elif casematch(arg, b'limit') and i + 2 < len(args):
try:
limit_start = Int.decode(args[i + 1])
limit_count = Int.decode(args[i + 2])
except SimpleError:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
else:
i += 2
elif casematch(arg, b'store') and i + 1 < len(args):
store = args[i + 1]
i += 1
elif casematch(arg, b'by') and i + 1 < len(args):
sortby = args[i + 1]
if b'*' not in sortby:
dontsort = True
i += 1
elif casematch(arg, b'get') and i + 1 < len(args):
if casematch(args[i], b'get') and i + 1 < len(args):
get.append(args[i + 1])
i += 1
i += 2
else:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
i += 1

# TODO: force sorting if the object is a set and either in Lua or
# storing to a key, to match redis behaviour.
Expand Down
9 changes: 4 additions & 5 deletions fakeredis/commands_mixins/hash_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,10 @@ def hscan(self, key, cursor, *args):
@command((Key(Hash), bytes, bytes), (bytes, bytes))
def hset(self, key, *args):
h = key.value
created = 0
for i in range(0, len(args), 2):
if args[i] not in h:
created += 1
h[args[i]] = args[i + 1]
keys_count = len(h.keys())
h.update(dict(zip(*[iter(args)] * 2))) # https://stackoverflow.com/a/12739974/1056460
created = len(h.keys()) - keys_count

key.updated()
return created

Expand Down
4 changes: 2 additions & 2 deletions fakeredis/commands_mixins/scripting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fakeredis import _msgs as msgs
from fakeredis._commands import command, Int
from fakeredis._helpers import SimpleError, SimpleString, casenorm, OK, encode_command
from fakeredis._helpers import SimpleError, SimpleString, null_terminate, OK, encode_command

LOGGER = logging.getLogger('fakeredis')
REDIS_LOG_LEVELS = {
Expand Down Expand Up @@ -210,7 +210,7 @@ def script_exists(self, *args):

@command(name='script flush', fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT, )
def script_flush(self, *args):
if len(args) > 1 or (len(args) == 1 and casenorm(args[0]) not in {b'sync', b'async'}):
if len(args) > 1 or (len(args) == 1 and null_terminate(args[0]) not in {b'sync', b'async'}):
raise SimpleError(msgs.BAD_SUBCOMMAND_MSG.format('SCRIPT'))
self.script_cache = {}
return OK
Expand Down
10 changes: 4 additions & 6 deletions fakeredis/commands_mixins/server_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ def dbsize(self):

@command((), (bytes,))
def flushdb(self, *args):
if args:
if len(args) != 1 or not casematch(args[0], b'async'):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b'async')):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
self._db.clear()
return OK

@command((), (bytes,))
def flushall(self, *args):
if args:
if len(args) != 1 or not casematch(args[0], b'async'):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if len(args) > 0 and (len(args) != 1 or not casematch(args[0], b'async')):
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
for db in self._server.dbs.values():
db.clear()
# TODO: clear watches and/or pubsub as well?
Expand Down
Loading