Skip to content

Commit

Permalink
Implement sorted set commands
Browse files Browse the repository at this point in the history
Fix #187
Fix #188
Fix #189
Fix #194
  • Loading branch information
cunla committed Jun 16, 2023
1 parent 27cd3f8 commit b5ea99c
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 28 deletions.
9 changes: 5 additions & 4 deletions docs/about/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ description: Change log of all fakeredis releases
### 🚀 Features

- Implemented support for various stream groups commands:
- `XGROUP CREATE` #161, `XGROUP DESTROY` #164, `XGROUP SETID` #165, `XGROUP DELCONSUMER` #162,
`XGROUP CREATECONSUMER` #163, `XINFO GROUPS` #168, `XINFO CONSUMERS` #168, `XINFO STREAM` #169, `XREADGROUP` #171,
`XACK` #157, `XPENDING` #170
- Implemented `ZRANDMEMBER` #192
- `XGROUP CREATE` #161, `XGROUP DESTROY` #164, `XGROUP SETID` #165, `XGROUP DELCONSUMER` #162,
`XGROUP CREATECONSUMER` #163, `XINFO GROUPS` #168, `XINFO CONSUMERS` #168, `XINFO STREAM` #169, `XREADGROUP` #171,
`XACK` #157, `XPENDING` #170
- Implemented sorted set commands:
- `ZRANDMEMBER` #192, `ZDIFF` #187, `ZINTER` #189, `ZUNION` #194, `ZDIFFSTORE` #188

### 🧰 Maintenance

Expand Down
32 changes: 16 additions & 16 deletions docs/redis-commands/Redis.md
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,22 @@ Returns the number of members in a sorted set.

Returns the count of members in a sorted set that have scores within a range.

### [ZDIFF](https://redis.io/commands/zdiff/)

Returns the difference between multiple sorted sets.

### [ZDIFFSTORE](https://redis.io/commands/zdiffstore/)

Stores the difference of multiple sorted sets in a key.

### [ZINCRBY](https://redis.io/commands/zincrby/)

Increments the score of a member in a sorted set.

### [ZINTER](https://redis.io/commands/zinter/)

Returns the intersect of multiple sorted sets.

### [ZINTERSTORE](https://redis.io/commands/zinterstore/)

Stores the intersect of multiple sorted sets in a key.
Expand Down Expand Up @@ -855,6 +867,10 @@ Iterates over members and scores of a sorted set.

Returns the score of a member in a sorted set.

### [ZUNION](https://redis.io/commands/zunion/)

Returns the union of multiple sorted sets.

### [ZUNIONSTORE](https://redis.io/commands/zunionstore/)

Stores the union of multiple sorted sets in a key.
Expand All @@ -867,18 +883,6 @@ Stores the union of multiple sorted sets in a key.

Removes and returns a member by score from one or more sorted sets. Blocks until a member is available otherwise. Deletes the sorted set if the last element was popped.

#### [ZDIFF](https://redis.io/commands/zdiff/) <small>(not implemented)</small>

Returns the difference between multiple sorted sets.

#### [ZDIFFSTORE](https://redis.io/commands/zdiffstore/) <small>(not implemented)</small>

Stores the difference of multiple sorted sets in a key.

#### [ZINTER](https://redis.io/commands/zinter/) <small>(not implemented)</small>

Returns the intersect of multiple sorted sets.

#### [ZINTERCARD](https://redis.io/commands/zintercard/) <small>(not implemented)</small>

Returns the number of members of the intersect of multiple sorted sets.
Expand All @@ -891,10 +895,6 @@ Returns the highest- or lowest-scoring members from one or more sorted sets afte

Stores a range of members from sorted set in a key.

#### [ZUNION](https://redis.io/commands/zunion/) <small>(not implemented)</small>

Returns the union of multiple sorted sets.


## generic commands

Expand Down
66 changes: 58 additions & 8 deletions fakeredis/commands_mixins/sortedset_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
from fakeredis._helpers import (SimpleError, casematch, null_terminate, )
from fakeredis._zset import ZSet

SORTED_SET_METHODS = {
'ZUNIONSTORE': lambda s1, s2: s1 | s2,
'ZUNION': lambda s1, s2: s1 | s2,
'ZINTERSTORE': lambda s1, s2: s1.intersection(s2),
'ZINTER': lambda s1, s2: s1.intersection(s2),
'ZDIFFSTORE': lambda s1, s2: s1 - s2,
'ZDIFF': lambda s1, s2: s1 - s2,
}


class SortedSetCommandsMixin:
# Sorted set commands
Expand Down Expand Up @@ -317,7 +326,7 @@ def _get_zset(value):
else:
raise SimpleError(msgs.WRONGTYPE_MSG)

def _zunioninter(self, func, dest, numkeys, *args):
def _zunioninterdiff(self, func, dest, numkeys, *args):
if numkeys < 1:
raise SimpleError(msgs.ZUNIONSTORE_KEYS_MSG.format(func.lower()))
if numkeys > len(args):
Expand Down Expand Up @@ -348,11 +357,9 @@ def _zunioninter(self, func, dest, numkeys, *args):
sets.append(self._get_zset(item.value))

out_members = set(sets[0])
method = SORTED_SET_METHODS[func]
for s in sets[1:]:
if func == 'ZUNIONSTORE':
out_members |= set(s)
else:
out_members.intersection_update(s)
out_members = method(out_members, set(s))

# We first build a regular dict and turn it into a ZSet. The
# reason is subtle: a ZSet won't update a score from -0 to +0
Expand All @@ -366,7 +373,7 @@ def _zunioninter(self, func, dest, numkeys, *args):
score *= w
# Redis only does this step for ZUNIONSTORE. See
# https://github.com/antirez/redis/issues/3954.
if func == 'ZUNIONSTORE' and math.isnan(score):
if func in {'ZUNIONSTORE', 'ZUNION'} and math.isnan(score):
score = 0.0
if member not in out_members:
continue
Expand All @@ -390,16 +397,59 @@ def _zunioninter(self, func, dest, numkeys, *args):
for member, score in out.items():
out_zset[member] = score

if dest is None:
return out_zset

dest.value = out_zset
return len(out_zset)

@command((Key(), Int, bytes), (bytes,))
def zunionstore(self, dest, numkeys, *args):
return self._zunioninter('ZUNIONSTORE', dest, numkeys, *args)
return self._zunioninterdiff('ZUNIONSTORE', dest, numkeys, *args)

@command((Key(), Int, bytes), (bytes,))
def zinterstore(self, dest, numkeys, *args):
return self._zunioninter('ZINTERSTORE', dest, numkeys, *args)
return self._zunioninterdiff('ZINTERSTORE', dest, numkeys, *args)

@command((Key(), Int, bytes), (bytes,))
def zdiffstore(self, dest, numkeys, *args):
return self._zunioninterdiff('ZDIFFSTORE', dest, numkeys, *args)

@command((Int, bytes,), (bytes,))
def zdiff(self, numkeys, *args):
withscores = casematch(b'withscores', args[-1])
sets = args[:-1] if withscores else args
res = self._zunioninterdiff('ZDIFF', None, numkeys, *sets)

if withscores:
res = [item for t in res for item in (t, Float.encode(res[t], False))]
else:
res = [t for t in res]
return res

@command((Int, bytes,), (bytes,))
def zunion(self, numkeys, *args):
withscores = casematch(b'withscores', args[-1])
sets = args[:-1] if withscores else args
res = self._zunioninterdiff('ZUNION', None, numkeys, *sets)

if withscores:
res = [item for t in res for item in (t, Float.encode(res[t], False))]
else:
res = [t for t in res]
return res

@command((Int, bytes,), (bytes,))
def zinter(self, numkeys, *args):
withscores = casematch(b'withscores', args[-1])
sets = args[:-1] if withscores else args
res = self._zunioninterdiff('ZINTER', None, numkeys, *sets)

if withscores:
res = [item for t in res for item in (t, Float.encode(res[t], False))]
else:
res = [t for t in res]
return res

@command(name="ZMSCORE", fixed=(Key(ZSet), bytes), repeat=(bytes,))
def zmscore(self, key: CommandItem, *members: Union[str, bytes]) -> list[Optional[float]]:
Expand Down
55 changes: 55 additions & 0 deletions test/test_mixins/test_sortedset_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,3 +1091,58 @@ def test_zrandemember(r: redis.Redis):
assert len(r.zrandmember("a", 10)) == 5
# with duplications
assert len(r.zrandmember("a", -10)) == 10


def test_zdiffstore(r: redis.Redis):
r.zadd("a", {"a1": 1, "a2": 2, "a3": 3})
r.zadd("b", {"a1": 1, "a2": 2})
assert r.zdiffstore("out", ["a", "b"])
assert r.zrange("out", 0, -1) == [b"a3"]
assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)]


def test_zdiff(r: redis.Redis):
r.zadd("a", {"a1": 1, "a2": 2, "a3": 3})
r.zadd("b", {"a1": 1, "a2": 2})
assert r.zdiff(["a", "b"]) == [b"a3"]
assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"]


def test_zunion(r: redis.Redis):
r.zadd("a", {"a1": 1, "a2": 1, "a3": 1})
r.zadd("b", {"a1": 2, "a2": 2, "a3": 2})
r.zadd("c", {"a1": 6, "a3": 5, "a4": 4})
# sum
assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"]
assert r.zunion(["a", "b", "c"], withscores=True) == [
(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9), ]
# max
assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [
(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6), ]
# min
assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [
(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4), ]
# with weight
assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [
(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23), ]


def test_zinter(r: redis.Redis):
r.zadd("a", {"a1": 1, "a2": 2, "a3": 1})
r.zadd("b", {"a1": 2, "a2": 2, "a3": 2})
r.zadd("c", {"a1": 6, "a3": 5, "a4": 4})
assert r.zinter(["a", "b", "c"]) == [b"a3", b"a1"]
# invalid aggregation
with pytest.raises(redis.DataError):
r.zinter(["a", "b", "c"], aggregate="foo", withscores=True)
# aggregate with SUM
assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)]
# aggregate with MAX
assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [
(b"a3", 5), (b"a1", 6), ]
# aggregate with MIN
assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [
(b"a1", 1), (b"a3", 1), ]
# with weights
assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [
(b"a3", 20), (b"a1", 23), ]

0 comments on commit b5ea99c

Please sign in to comment.