Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(zset_family): support WITHSCORE in zrevrank/zrank commands (#3921) #4001

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/core/bptree_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ template <typename T, typename Policy = BPTreePolicy<T>> class BPTree {

bool Delete(KeyT item);

std::optional<uint32_t> GetRank(KeyT item) const;
std::optional<uint32_t> GetRank(KeyT item, bool reverse = false) const;

size_t Height() const {
return height_;
Expand Down Expand Up @@ -222,7 +222,7 @@ template <typename T, typename Policy> bool BPTree<T, Policy>::Delete(KeyT item)
}

template <typename T, typename Policy>
std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item) const {
std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item, bool reverse) const {
if (!root_)
return std::nullopt;

Expand All @@ -231,6 +231,10 @@ std::optional<uint32_t> BPTree<T, Policy>::GetRank(KeyT item) const {
if (!found)
return std::nullopt;

if (reverse) {
return count_ - path.Rank() - 1;
}

return path.Rank();
}

Expand Down
14 changes: 12 additions & 2 deletions src/core/sorted_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ optional<unsigned> SortedMap::GetRank(sds ele, bool reverse) const {
if (obj == nullptr)
return std::nullopt;

optional rank = score_tree->GetRank(obj);
optional rank = score_tree->GetRank(obj, reverse);
DCHECK(rank);
return reverse ? score_map->UpperBoundSize() - *rank - 1 : *rank;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was UpperBoundSize() used here? It isn't obvious, perhaps something slipped away from me.

IMHO: I'd like to keep DCHECK. It's true that score_map->GetRank shouldn't return std::nullopt in this codepath. It can do it in general though. So it wouldn't hurt.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dense_set has the ability to expire items but it does not do it proactively, so its size method may include items that are already expired. We do not use the feature in sorted set though.

return *rank;
}

SortedMap::ScoredArray SortedMap::GetRange(const zrangespec& range, unsigned offset, unsigned limit,
Expand Down Expand Up @@ -783,5 +783,15 @@ bool SortedMap::DefragIfNeeded(float ratio) {
return reallocated;
}

std::optional<SortedMap::RankAndScore> SortedMap::GetRankAndScore(sds ele, bool reverse) const {
ScoreSds obj = score_map->FindObj(ele);
if (obj == nullptr)
return std::nullopt;

optional rank = score_tree->GetRank(obj, reverse);
DCHECK(rank);

return SortedMap::RankAndScore{*rank, GetObjScore(obj)};
}
} // namespace detail
} // namespace dfly
2 changes: 2 additions & 0 deletions src/core/sorted_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SortedMap {
using ScoredMember = std::pair<std::string, double>;
using ScoredArray = std::vector<ScoredMember>;
using ScoreSds = void*;
using RankAndScore = std::pair<unsigned, double>;

SortedMap(PMR_NS::memory_resource* res);
~SortedMap();
Expand Down Expand Up @@ -72,6 +73,7 @@ class SortedMap {

std::optional<double> GetScore(sds ele) const;
std::optional<unsigned> GetRank(sds ele, bool reverse) const;
std::optional<RankAndScore> GetRankAndScore(sds ele, bool reverse) const;
ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const;
ScoredArray GetLexRange(const zlexrangespec& r, unsigned o, unsigned l, bool rev) const;

Expand Down
80 changes: 62 additions & 18 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,8 +1387,13 @@ OpResult<unsigned> OpRemRange(const OpArgs& op_args, string_view key,
return iv.removed();
}

OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse) {
struct RankResult {
unsigned rank;
double score = 0;
};

OpResult<RankResult> OpRank(const OpArgs& op_args, string_view key, string_view member,
bool reverse, bool with_score) {
auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
Expand Down Expand Up @@ -1417,18 +1422,34 @@ OpResult<unsigned> OpRank(const OpArgs& op_args, string_view key, string_view me
if (eptr == NULL)
return OpStatus::KEY_NOTFOUND;

if (reverse) {
return lpLength(zl) / 2 - rank;
RankResult res{};
res.rank = reverse ? lpLength(zl) / 2 - rank : rank - 1;
if (with_score) {
res.score = zzlGetScore(sptr);
}
return rank - 1;
return res;
}
DCHECK_EQ(robj_wrapper->encoding(), OBJ_ENCODING_SKIPLIST);
detail::SortedMap* ss = (detail::SortedMap*)robj_wrapper->inner_obj();
std::optional<unsigned> rank = ss->GetRank(WrapSds(member), reverse);
if (!rank)
return OpStatus::KEY_NOTFOUND;

return *rank;
RankResult res{};

if (with_score) {
auto rankAndScore = ss->GetRankAndScore(WrapSds(member), reverse);
if (!rankAndScore) {
return OpStatus::KEY_NOTFOUND;
}
res.rank = rankAndScore->first;
res.score = rankAndScore->second;
} else {
std::optional<unsigned> rank = ss->GetRank(WrapSds(member), reverse);
if (!rank) {
return OpStatus::KEY_NOTFOUND;
}
res.rank = *rank;
}

return res;
}

OpResult<unsigned> OpCount(const OpArgs& op_args, std::string_view key,
Expand Down Expand Up @@ -1979,17 +2000,40 @@ void ZRangeGeneric(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
}

void ZRankGeneric(CmdArgList args, bool reverse, Transaction* tx, SinkReplyBuilder* builder) {
string_view key = ArgS(args, 0);
string_view member = ArgS(args, 1);
// send this error exact as redis does, it checks number of arguments first
if (args.size() > 3) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you like to try replacing this parsing logic with facade::CmdArgParser ?
we usually replace the old parsing code with this as it's more clear this way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out that redis does this check first. so i left this check in favor to throw the same error messages.

Providing this check by using the parser isn't obvious.
I tried to write a clean code by using parser for that and failed =D

return builder->SendError(WrongNumArgsError(reverse ? "ZREVRANK" : "ZRANK"));
}

facade::CmdArgParser parser(args);

string_view key = parser.Next();
string_view member = parser.Next();
bool with_score = false;

if (parser.HasNext()) {
parser.ExpectTag("WITHSCORE");
with_score = true;
}

if (!parser.Finalize()) {
return builder->SendError(parser.Error()->MakeReply());
}

auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRank(t->GetOpArgs(shard), key, member, reverse);
return OpRank(t->GetOpArgs(shard), key, member, reverse, with_score);
};

OpResult<RankResult> result = tx->ScheduleSingleHopT(std::move(cb));
auto* rb = static_cast<RedisReplyBuilder*>(builder);
OpResult<unsigned> result = tx->ScheduleSingleHopT(std::move(cb));
if (result) {
rb->SendLong(*result);
if (with_score) {
rb->StartArray(2);
rb->SendLong(result->rank);
rb->SendDouble(result->score);
} else {
rb->SendLong(result->rank);
}
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
rb->SendNull();
} else {
Expand Down Expand Up @@ -2340,7 +2384,7 @@ void ZSetFamily::ZRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil
}

void ZSetFamily::ZRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
ZRankGeneric(std::move(args), false, tx, builder);
ZRankGeneric(args, false, tx, builder);
}

void ZSetFamily::ZRevRange(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
Expand All @@ -2362,7 +2406,7 @@ void ZSetFamily::ZRevRangeByScore(CmdArgList args, Transaction* tx, SinkReplyBui
}

void ZSetFamily::ZRevRank(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
ZRankGeneric(std::move(args), true, tx, builder);
ZRankGeneric(args, true, tx, builder);
}

void ZSetFamily::ZRangeByLex(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
Expand Down Expand Up @@ -3213,7 +3257,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, acl::kZRem}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, acl::kZRange}.HFUNC(ZRange)
<< CI{"ZRANDMEMBER", CO::READONLY, -2, 1, 1, acl::kZRandMember}.HFUNC(ZRandMember)
<< CI{"ZRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRank}.HFUNC(ZRank)
<< CI{"ZRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRank}.HFUNC(ZRank)
<< CI{"ZRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRangeByLex}.HFUNC(ZRangeByLex)
<< CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRangeByScore}.HFUNC(ZRangeByScore)
<< CI{"ZRANGESTORE", CO::WRITE | CO::DENYOOM, -5, 1, 2, acl::kZRangeStore}.HFUNC(ZRangeStore)
Expand All @@ -3226,7 +3270,7 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREVRANGEBYLEX", CO::READONLY, -4, 1, 1, acl::kZRevRangeByLex}.HFUNC(ZRevRangeByLex)
<< CI{"ZREVRANGEBYSCORE", CO::READONLY, -4, 1, 1, acl::kZRevRangeByScore}.HFUNC(
ZRevRangeByScore)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, 3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank)
<< CI{"ZREVRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank)
<< CI{"ZSCAN", CO::READONLY, -3, 1, 1, acl::kZScan}.HFUNC(ZScan)
<< CI{"ZUNION", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZUnion}.HFUNC(ZUnion)
<< CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, acl::kZUnionStore}.HFUNC(ZUnionStore)
Expand Down
30 changes: 30 additions & 0 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,43 @@ TEST_F(ZSetFamilyTest, ZRangeRank) {
EXPECT_EQ(2, CheckedInt({"zcount", "x", "1.1", "2.1"}));
EXPECT_EQ(1, CheckedInt({"zcount", "x", "(1.1", "2.1"}));
EXPECT_EQ(0, CheckedInt({"zcount", "y", "(1.1", "2.1"}));
}

TEST_F(ZSetFamilyTest, ZRank) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_EQ(0, CheckedInt({"zrank", "x", "a"}));
EXPECT_EQ(1, CheckedInt({"zrank", "x", "b"}));
EXPECT_EQ(1, CheckedInt({"zrevrank", "x", "a"}));
EXPECT_EQ(0, CheckedInt({"zrevrank", "x", "b"}));
EXPECT_THAT(Run({"zrevrank", "x", "c"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrank", "y", "c"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrevrank", "x", "c", "WITHSCORE"}), ArgType(RespExpr::NIL));
EXPECT_THAT(Run({"zrank", "y", "c", "WITHSCORE"}), ArgType(RespExpr::NIL));

auto resp = Run({"zrank", "x", "a", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "1.1"));

resp = Run({"zrank", "x", "b", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "2.1"));

resp = Run({"zrevrank", "x", "a", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), "1.1"));

resp = Run({"zrevrank", "x", "b", "WITHSCORE"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), "2.1"));

resp = Run({"zrank", "x", "a", "WITHSCORES"});
ASSERT_THAT(resp, ErrArg("syntax error"));

resp = Run({"zrank", "x", "a", "WITHSCORES", "42"});
ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrank' command"));

resp = Run({"zrevrank", "x", "a", "WITHSCORES", "42"});
ASSERT_THAT(resp, ErrArg("wrong number of arguments for 'zrevrank' command"));
}

TEST_F(ZSetFamilyTest, LargeSet) {
Expand Down
Loading