Skip to content

Commit

Permalink
feat: add SRANDMEMBER and ZRANDMEMBER (#2148)
Browse files Browse the repository at this point in the history
* feat: add SRANDMEMBER and ZRANDMEMBER

* fix: fix SRANDMEMBER and ZRANDMEMBER commands behaviour

* fix: fix type and remove extra flag
  • Loading branch information
BorysTheDev authored Nov 10, 2023
1 parent 5381746 commit 587660f
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/facade/cmd_arg_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ template <typename T> T CmdArgParser::NextProxy::Int() {

template uint64_t CmdArgParser::NextProxy::Int<uint64_t>();
template int64_t CmdArgParser::NextProxy::Int<int64_t>();
template uint32_t CmdArgParser::NextProxy::Int<uint32_t>();
template int32_t CmdArgParser::NextProxy::Int<int32_t>();

ErrorReply CmdArgParser::ErrorInfo::MakeReply() const {
switch (type) {
Expand Down
63 changes: 63 additions & 0 deletions src/server/set_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extern "C" {
#include "base/logging.h"
#include "base/stl_util.h"
#include "core/string_set.h"
#include "facade/cmd_arg_parser.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
Expand Down Expand Up @@ -1344,6 +1345,66 @@ void SMembers(CmdArgList args, ConnectionContext* cntx) {
}
}

void SRandMember(CmdArgList args, ConnectionContext* cntx) {
CmdArgParser parser{args};
string_view key = parser.Next();

bool is_count = parser.HasNext();
int count = is_count ? parser.Next().Int<int>() : 1;

if (parser.HasNext())
return (*cntx)->SendError(WrongNumArgsError("SRANDMEMBER"));

if (auto err = parser.Error(); err)
return (*cntx)->SendError(err->MakeReply());

const unsigned ucount = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
StringVec result;
OpResult<PrimeIterator> find_res = shard->db_slice().Find(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}

PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
}

container_utils::IterateSet(find_res.value()->second,
[&result, ucount](container_utils::ContainerEntry ce) {
if (result.size() < ucount) {
result.push_back(ce.ToString());
return true;
}
return false;
});
return result;
};

OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);

if (result) {
if (count < 0 && !result->empty()) {
for (auto i = result->size(); i < ucount; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
(*cntx)->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
(*cntx)->SendStringArr(StringVec(), RedisReplyBuilder::SET);
} else {
(*cntx)->SendNull();
}
} else {
(*cntx)->SendError(result.status());
}
}

void SInter(CmdArgList args, ConnectionContext* cntx) {
ResultStringVec result_set(shard_set->size(), OpStatus::SKIPPED);

Expand Down Expand Up @@ -1628,6 +1689,7 @@ constexpr uint32_t kSMove = WRITE | SET | FAST;
constexpr uint32_t kSRem = WRITE | SET | FAST;
constexpr uint32_t kSCard = READ | SET | FAST;
constexpr uint32_t kSPop = WRITE | SET | SLOW;
constexpr uint32_t kSRandMember = READ | SET | SLOW;
constexpr uint32_t kSUnion = READ | SET | SLOW;
constexpr uint32_t kSUnionStore = WRITE | SET | SLOW;
constexpr uint32_t kSScan = READ | SET | SLOW;
Expand Down Expand Up @@ -1656,6 +1718,7 @@ void SetFamily::Register(CommandRegistry* registry) {
<< CI{"SREM", CO::WRITE | CO::FAST, -3, 1, 1, 1, acl::kSRem}.HFUNC(SRem)
<< CI{"SCARD", CO::READONLY | CO::FAST, 2, 1, 1, 1, acl::kSCard}.HFUNC(SCard)
<< CI{"SPOP", CO::WRITE | CO::FAST | CO::NO_AUTOJOURNAL, -2, 1, 1, 1, acl::kSPop}.HFUNC(SPop)
<< CI{"SRANDMEMBER", CO::READONLY, -2, 1, 1, 1, acl::kSRandMember}.HFUNC(SRandMember)
<< CI{"SUNION", CO::READONLY, -2, 1, -1, 1, acl::kSUnion}.HFUNC(SUnion)
<< CI{"SUNIONSTORE", CO::WRITE | CO::DENYOOM | CO::NO_AUTOJOURNAL, -3, 1, -1, 1,
acl::kSUnionStore}
Expand Down
33 changes: 33 additions & 0 deletions src/server/set_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,39 @@ TEST_F(SetFamilyTest, SPop) {
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));
}

TEST_F(SetFamilyTest, SRandMember) {
auto resp = Run({"sadd", "x", "1", "2", "3"});
resp = Run({"SRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "1");

resp = Run({"SRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2"));

resp = Run({"SRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);

resp = Run({"SRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));

resp = Run({"SRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);

resp = Run({"SRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "2", "3", "1", "1"));

resp = Run({"SRandMember", "x", "5"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3"));

resp = Run({"SRandMember", "x", "5", "3"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
}

TEST_F(SetFamilyTest, SMIsMember) {
Run({"sadd", "foo", "a"});
Run({"sadd", "foo", "b"});
Expand Down
52 changes: 52 additions & 0 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extern "C" {
#include "base/logging.h"
#include "base/stl_util.h"
#include "core/sorted_map.h"
#include "facade/cmd_arg_parser.h"
#include "facade/error.h"
#include "server/blocking_controller.h"
#include "server/command_registry.h"
Expand Down Expand Up @@ -2279,6 +2280,55 @@ void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
}
}

void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 3)
return (*cntx)->SendError(WrongNumArgsError("ZRANDMEMBER"));

ZRangeSpec range_spec;
range_spec.interval = IndexInterval(0, -1);

CmdArgParser parser{args};
string_view key = parser.Next();

bool is_count = parser.HasNext();
int count = is_count ? parser.Next().Int<int>() : 1;

range_spec.params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());

if (parser.HasNext())
return (*cntx)->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next())));

if (auto err = parser.Error(); err)
return (*cntx)->SendError(err->MakeReply());

bool sign = count < 0;
range_spec.params.limit = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRange(range_spec, t->GetOpArgs(shard), key);
};

OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(cb);

if (result) {
if (sign && !result->empty()) {
for (auto i = result->size(); i < range_spec.params.limit; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
(*cntx)->SendScoredArray(result.value(), range_spec.params.with_scores);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
(*cntx)->SendScoredArray(ScoredArray(), range_spec.params.with_scores);
} else {
(*cntx)->SendNull();
}
} else {
(*cntx)->SendError(result.status());
}
}

void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
string_view member = ArgS(args, 1);
Expand Down Expand Up @@ -3022,6 +3072,7 @@ constexpr uint32_t kZPopMax = WRITE | SORTEDSET | FAST;
constexpr uint32_t kZPopMin = WRITE | SORTEDSET | FAST;
constexpr uint32_t kZRem = WRITE | SORTEDSET | FAST;
constexpr uint32_t kZRange = READ | SORTEDSET | SLOW;
constexpr uint32_t kZRandMember = READ | SORTEDSET | SLOW;
constexpr uint32_t kZRank = READ | SORTEDSET | FAST;
constexpr uint32_t kZRangeByLex = READ | SORTEDSET | SLOW;
constexpr uint32_t kZRangeByScore = READ | SORTEDSET | SLOW;
Expand Down Expand Up @@ -3079,6 +3130,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZPOPMIN", CO::FAST | CO::WRITE, -2, 1, 1, 1, acl::kZPopMin}.HFUNC(ZPopMin)
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1, acl::kZRem}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1, acl::kZRange}.HFUNC(ZRange)
<< CI{"ZRANDMEMBER", CO::READONLY, -2, 1, 1, 1, acl::kZRandMember}.HFUNC(ZRandMember)
<< CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, 1, acl::kZRange}.HFUNC(ZRank)
<< CI{"ZRANGEBYLEX", CO::READONLY, -4, 1, 1, 1, acl::kZRangeByLex}.HFUNC(ZRangeByLex)
<< CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1, acl::kZRangeByScore}.HFUNC(ZRangeByScore)
Expand Down
1 change: 1 addition & 0 deletions src/server/zset_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ZSetFamily {
static void ZRange(CmdArgList args, ConnectionContext* cntx);
static void ZRank(CmdArgList args, ConnectionContext* cntx);
static void ZRem(CmdArgList args, ConnectionContext* cntx);
static void ZRandMember(CmdArgList args, ConnectionContext* cntx);
static void ZScore(CmdArgList args, ConnectionContext* cntx);
static void ZMScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByLex(CmdArgList args, ConnectionContext* cntx);
Expand Down
50 changes: 50 additions & 0 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,56 @@ TEST_F(ZSetFamilyTest, ZRem) {
EXPECT_THAT(Run({"zrange", "x", "(-inf", "(+inf", "byscore"}), "a");
}

TEST_F(ZSetFamilyTest, ZRandMember) {
auto resp = Run({
"zadd",
"x",
"1",
"a",
"2",
"b",
"3",
"c",
});
resp = Run({"ZRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "a");

resp = Run({"ZRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b"));

resp = Run({"ZRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);

resp = Run({"ZRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));

resp = Run({"ZRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);

resp = Run({"ZRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a"));

resp = Run({"ZRandMember", "x", "5"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(10));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1"));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3"));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
}

TEST_F(ZSetFamilyTest, ZMScore) {
Run({"zadd", "zms", "3.14", "a"});
Run({"zadd", "zms", "42", "another"});
Expand Down

0 comments on commit 587660f

Please sign in to comment.