From 12b035021861365d5e71254de9ec0a859866fdd5 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Fri, 20 Jan 2023 18:39:55 +0300 Subject: [PATCH 1/9] feat(string family): implement cl.throttle --- src/server/string_family.cc | 149 ++++++++++++++++++++++++++++++- src/server/string_family.h | 2 + src/server/string_family_test.cc | 63 +++++++++++++ 3 files changed, 213 insertions(+), 1 deletion(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 8dc8c3b6a446..4340a9b026df 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -11,7 +11,10 @@ extern "C" { #include #include +#include +#include #include +#include #include "base/logging.h" #include "redis/util.h" @@ -356,6 +359,92 @@ OpResult SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar return cntx->transaction->ScheduleSingleHop(std::move(cb)); } +OpResult> OpThrottle(const OpArgs& op_args, string_view key, int64_t max_burst, + int64_t count, int64_t period_s, int64_t quantity) { + using namespace chrono_literals; + + auto& db_slice = op_args.shard->db_slice(); + + const int64_t limit = max_burst + 1; + const int64_t emission_interval_ms = period_s * 1000 / count; + const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; + + if (emission_interval_ms == 0) { + return OpStatus::OUT_OF_RANGE; + } + + int64_t remaining = 0; + int64_t reset_after_ms = -1000; + int64_t retry_after_ms = -1000; + + const int64_t increment_ms = emission_interval_ms * quantity; + + auto [it, expire_it] = db_slice.FindExt(op_args.db_cntx, key); + const int64_t now_ms = op_args.db_cntx.time_now_ms; + + int64_t tat_ms = now_ms; + if (IsValid(it)) { + if (it->second.ObjType() != OBJ_STRING) { + return OpStatus::WRONG_TYPE; + } + + auto opt_prev = it->second.TryGetInt(); + if (!opt_prev) { + return OpStatus::INVALID_VALUE; + } + tat_ms = *opt_prev; + } + const int64_t new_tat_ms = max(tat_ms, now_ms) + increment_ms; + + const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms; + const int64_t diff_ms = now_ms - allow_at_ms; + + const bool limited = diff_ms < 0; + int64_t ttl_ms; + if (limited) { + if (increment_ms <= delay_variation_tolerance_ms) { + retry_after_ms = -diff_ms; + } + ttl_ms = tat_ms - now_ms; + } else { + ttl_ms = new_tat_ms - now_ms; + + if (IsValid(it)) { + db_slice.PreUpdate(op_args.db_cntx.db_index, it); + it->second.SetInt(new_tat_ms); + db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); + } else { + CompactObj cobj; + cobj.SetInt(new_tat_ms); + + // AddNew calls PostUpdate inside. + try { + it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), 0); + } catch (bad_alloc&) { + return OpStatus::OUT_OF_MEMORY; + } + } + } + + const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms; + if (next_ms > -emission_interval_ms) { + remaining = next_ms / emission_interval_ms; + } + reset_after_ms = ttl_ms; + + int64_t retry_after_s = retry_after_ms / 1000; + if (retry_after_ms > 0) { + retry_after_s += 1; + } + + int64_t reset_after_s = reset_after_ms / 1000; + if (reset_after_ms > 0) { + reset_after_s += 1; + } + + return array{limited ? 1 : 0, limit, remaining, retry_after_s, reset_after_s}; +} + } // namespace OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) { @@ -1170,6 +1259,63 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction return response; } +void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + + int64_t max_burst; + string_view max_burst_str = ArgS(args, 2); + if (!absl::SimpleAtoi(max_burst_str, &max_burst)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t count; + string_view count_str = ArgS(args, 3); + if (!absl::SimpleAtoi(count_str, &count)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t period; + string_view period_str = ArgS(args, 4); + if (!absl::SimpleAtoi(period_str, &period)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t quantity = 1; + if (args.size() > 5) { + string_view quantity_str = ArgS(args, 5); + + if (!absl::SimpleAtoi(quantity_str, &quantity)) { + return (*cntx)->SendError(kInvalidIntErr); + } + } + + auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult> { + return OpThrottle(t->GetOpArgs(shard), key, max_burst, count, period, quantity); + }; + + Transaction* trans = cntx->transaction; + OpResult> result = trans->ScheduleSingleHopT(std::move(cb)); + + switch (result.status()) { + case OpStatus::WRONG_TYPE: + (*cntx)->SendError(result.status()); + break; + case OpStatus::INVALID_VALUE: + (*cntx)->SendError(kInvalidIntErr); + break; + case OpStatus::OUT_OF_RANGE: + (*cntx)->SendError(kIncrOverflow); + break; + default: + (*cntx)->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + (*cntx)->SendLong(v); + } + break; + } +} + void StringFamily::Init(util::ProactorPool* pp) { set_qps.Init(pp); get_qps.Init(pp); @@ -1206,7 +1352,8 @@ void StringFamily::Register(CommandRegistry* registry) { << CI{"STRLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(StrLen) << CI{"GETRANGE", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange) << CI{"SUBSTR", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange) // Alias for GetRange - << CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange); + << CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange) + << CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1}.HFUNC(ClThrottle); } } // namespace dfly diff --git a/src/server/string_family.h b/src/server/string_family.h index 20aa2b1d5ed4..ec5e5b22249b 100644 --- a/src/server/string_family.h +++ b/src/server/string_family.h @@ -85,6 +85,8 @@ class StringFamily { static void Prepend(CmdArgList args, ConnectionContext* cntx); static void PSetEx(CmdArgList args, ConnectionContext* cntx); + static void ClThrottle(CmdArgList args, ConnectionContext* cntx); + // These functions are used internally, they do not implement any specific command static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx); static void ExtendGeneric(CmdArgList args, bool prepend, ConnectionContext* cntx); diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index f914b8ca3dfa..fa392800be93 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -557,4 +557,67 @@ TEST_F(StringFamilyTest, GetEx) { EXPECT_THAT(Run({"getex", "foo"}), ArgType(RespExpr::NIL)); } +TEST_F(StringFamilyTest, ClThrottle) { + const int64_t limit = 5; + const char* const key = "foo"; + const char* const max_burst = "4"; // limit - 1 + const char* const count = "1"; + const char* const period = "10"; + + // You can never make a request larger than the maximum. + auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0))); + + // Rate limit normal requests appropriately. + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51))); + + AdvanceTime(30000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + AdvanceTime(1000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40))); + + AdvanceTime(9000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + + AdvanceTime(40000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + AdvanceTime(15000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + // Zero-volume request just peeks at the state. + resp = Run({"cl.throttle", key, max_burst, count, period, "0"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + // High-volume request uses up more of the limit. + resp = Run({"cl.throttle", key, max_burst, count, period, "2"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + // Large requests cannot exceed limits + resp = Run({"cl.throttle", key, max_burst, count, period, "5"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); +} + } // namespace dfly From 0131afaa2a90f2a9bf9977b27c829d5795559310 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Mon, 23 Jan 2023 15:17:49 +0300 Subject: [PATCH 2/9] add ttl to cl.throttle --- src/server/string_family.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 4340a9b026df..ab0fda558249 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -379,7 +379,7 @@ OpResult> OpThrottle(const OpArgs& op_args, string_view key, i const int64_t increment_ms = emission_interval_ms * quantity; - auto [it, expire_it] = db_slice.FindExt(op_args.db_cntx, key); + auto [it, e_it] = db_slice.FindExt(op_args.db_cntx, key); const int64_t now_ms = op_args.db_cntx.time_now_ms; int64_t tat_ms = now_ms; @@ -410,6 +410,12 @@ OpResult> OpThrottle(const OpArgs& op_args, string_view key, i ttl_ms = new_tat_ms - now_ms; if (IsValid(it)) { + if (IsValid(e_it)) { + e_it->second = db_slice.FromAbsoluteTime(new_tat_ms); + } else { + db_slice.AddExpire(op_args.db_cntx.db_index, it, new_tat_ms); + } + db_slice.PreUpdate(op_args.db_cntx.db_index, it); it->second.SetInt(new_tat_ms); db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); @@ -419,7 +425,7 @@ OpResult> OpThrottle(const OpArgs& op_args, string_view key, i // AddNew calls PostUpdate inside. try { - it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), 0); + it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), new_tat_ms); } catch (bad_alloc&) { return OpStatus::OUT_OF_MEMORY; } From 4d3c924404911b6d8b1febee3a75368abcf834c3 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Mon, 23 Jan 2023 17:11:17 +0300 Subject: [PATCH 3/9] apply cl.throttle implementation suggestions --- src/server/string_family.cc | 51 ++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index ab0fda558249..1a243c04d403 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -359,20 +359,13 @@ OpResult SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar return cntx->transaction->ScheduleSingleHop(std::move(cb)); } -OpResult> OpThrottle(const OpArgs& op_args, string_view key, int64_t max_burst, - int64_t count, int64_t period_s, int64_t quantity) { - using namespace chrono_literals; - +OpResult> OpThrottle(const OpArgs& op_args, const string_view key, + const uint64_t limit, const int64_t emission_interval_ms, + const uint64_t quantity) { auto& db_slice = op_args.shard->db_slice(); - const int64_t limit = max_burst + 1; - const int64_t emission_interval_ms = period_s * 1000 / count; const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; - if (emission_interval_ms == 0) { - return OpStatus::OUT_OF_RANGE; - } - int64_t remaining = 0; int64_t reset_after_ms = -1000; int64_t retry_after_ms = -1000; @@ -1265,38 +1258,50 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction return response; } +/* CL.THROTTLE [] */ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { - string_view key = ArgS(args, 1); + const string_view key = ArgS(args, 1); - int64_t max_burst; - string_view max_burst_str = ArgS(args, 2); + // Allow max burst in number of tokens + uint64_t max_burst; + const string_view max_burst_str = ArgS(args, 2); if (!absl::SimpleAtoi(max_burst_str, &max_burst)) { return (*cntx)->SendError(kInvalidIntErr); } - int64_t count; - string_view count_str = ArgS(args, 3); + // Emit count of tokens per period + uint64_t count; + const string_view count_str = ArgS(args, 3); if (!absl::SimpleAtoi(count_str, &count)) { return (*cntx)->SendError(kInvalidIntErr); } - int64_t period; - string_view period_str = ArgS(args, 4); + // Period of emitting count of tockens + uint64_t period; + const string_view period_str = ArgS(args, 4); if (!absl::SimpleAtoi(period_str, &period)) { return (*cntx)->SendError(kInvalidIntErr); } - int64_t quantity = 1; + // Apply quantity of tokens now + uint64_t quantity = 1; if (args.size() > 5) { - string_view quantity_str = ArgS(args, 5); + const string_view quantity_str = ArgS(args, 5); if (!absl::SimpleAtoi(quantity_str, &quantity)) { return (*cntx)->SendError(kInvalidIntErr); } } + const uint64_t limit = max_burst + 1; + const int64_t emission_interval_ms = period * 1000 / count; + + if (emission_interval_ms == 0) { + return (*cntx)->SendError("Zero rates are not supported"); + } + auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult> { - return OpThrottle(t->GetOpArgs(shard), key, max_burst, count, period, quantity); + return OpThrottle(t->GetOpArgs(shard), key, limit, emission_interval_ms, quantity); }; Transaction* trans = cntx->transaction; @@ -1304,13 +1309,13 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { switch (result.status()) { case OpStatus::WRONG_TYPE: - (*cntx)->SendError(result.status()); + (*cntx)->SendError(kWrongTypeErr); break; case OpStatus::INVALID_VALUE: (*cntx)->SendError(kInvalidIntErr); break; - case OpStatus::OUT_OF_RANGE: - (*cntx)->SendError(kIncrOverflow); + case OpStatus::OUT_OF_MEMORY: + (*cntx)->SendError(kOutOfMemory); break; default: (*cntx)->StartArray(result->size()); From a0b9eb40eb13034b4c97afef35f8c2727a65a081 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Mon, 23 Jan 2023 17:12:55 +0300 Subject: [PATCH 4/9] fix typo in cl.throttle impl --- src/server/string_family.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 1a243c04d403..8811526b2697 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1276,7 +1276,7 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendError(kInvalidIntErr); } - // Period of emitting count of tockens + // Period of emitting count of tokens uint64_t period; const string_view period_str = ArgS(args, 4); if (!absl::SimpleAtoi(period_str, &period)) { From f6e11b6efcde88b62135727d0c1219f9374954ca Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Tue, 24 Jan 2023 12:29:54 +0300 Subject: [PATCH 5/9] explicit result check in cl.throttle --- src/server/string_family.cc | 38 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 8811526b2697..cbe67392bfbf 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1307,23 +1307,27 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { Transaction* trans = cntx->transaction; OpResult> result = trans->ScheduleSingleHopT(std::move(cb)); - switch (result.status()) { - case OpStatus::WRONG_TYPE: - (*cntx)->SendError(kWrongTypeErr); - break; - case OpStatus::INVALID_VALUE: - (*cntx)->SendError(kInvalidIntErr); - break; - case OpStatus::OUT_OF_MEMORY: - (*cntx)->SendError(kOutOfMemory); - break; - default: - (*cntx)->StartArray(result->size()); - const auto& array = result.value(); - for (const auto& v : array) { - (*cntx)->SendLong(v); - } - break; + if (result) { + (*cntx)->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + (*cntx)->SendLong(v); + } + } else { + switch (result.status()) { + case OpStatus::WRONG_TYPE: + (*cntx)->SendError(kWrongTypeErr); + break; + case OpStatus::INVALID_VALUE: + (*cntx)->SendError(kInvalidIntErr); + break; + case OpStatus::OUT_OF_MEMORY: + (*cntx)->SendError(kOutOfMemory); + break; + default: + (*cntx)->SendError(result.status()); + break; + } } } From ea66cfb8867c1ddbf47aa9304daf75b8a99932e8 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Tue, 24 Jan 2023 12:32:18 +0300 Subject: [PATCH 6/9] format string_family_test.cc --- src/server/string_family_test.cc | 47 +++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index fa392800be93..fabe485d7871 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -560,64 +560,79 @@ TEST_F(StringFamilyTest, GetEx) { TEST_F(StringFamilyTest, ClThrottle) { const int64_t limit = 5; const char* const key = "foo"; - const char* const max_burst = "4"; // limit - 1 + const char* const max_burst = "4"; // limit - 1 const char* const count = "1"; const char* const period = "10"; // You can never make a request larger than the maximum. auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0))); // Rate limit normal requests appropriately. resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21))); resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51))); resp = Run({"cl.throttle", key, max_burst, count, period}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51))); AdvanceTime(30000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); AdvanceTime(1000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40))); AdvanceTime(9000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); AdvanceTime(40000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); AdvanceTime(15000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); // Zero-volume request just peeks at the state. resp = Run({"cl.throttle", key, max_burst, count, period, "0"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); // High-volume request uses up more of the limit. resp = Run({"cl.throttle", key, max_burst, count, period, "2"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); // Large requests cannot exceed limits resp = Run({"cl.throttle", key, max_burst, count, period, "5"}); - ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); + ASSERT_THAT(resp.GetVec(), + ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); } } // namespace dfly From 5741ee186d90239139de7d810060cbe4b34a9f36 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Tue, 24 Jan 2023 13:12:59 +0300 Subject: [PATCH 7/9] Port last cl.throttle test from redis-cell --- src/server/string_family.cc | 2 +- src/server/string_family_test.cc | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index cbe67392bfbf..73430d0c6dda 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1297,7 +1297,7 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { const int64_t emission_interval_ms = period * 1000 / count; if (emission_interval_ms == 0) { - return (*cntx)->SendError("Zero rates are not supported"); + return (*cntx)->SendError("zero rates are not supported"); } auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult> { diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index fabe485d7871..19c4638b0671 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -633,6 +633,9 @@ TEST_F(StringFamilyTest, ClThrottle) { resp = Run({"cl.throttle", key, max_burst, count, period, "5"}); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); + + // Zero rates aren't supported + EXPECT_THAT(Run({"cl.throttle", "bar", "10", "1", "0"}), ErrArg("zero rates are not supported")); } } // namespace dfly From 073b98812452e203744f2cd9ab7be63f070fede6 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Tue, 24 Jan 2023 14:57:00 +0300 Subject: [PATCH 8/9] add integer overflow checks to cl.throttle impl Also add one test when `count == 0` --- src/server/string_family.cc | 93 ++++++++++++++++++++++++-------- src/server/string_family_test.cc | 24 ++++++++- 2 files changed, 93 insertions(+), 24 deletions(-) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 73430d0c6dda..06baba283291 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -14,6 +14,7 @@ extern "C" { #include #include #include +#include #include #include "base/logging.h" @@ -359,18 +360,26 @@ OpResult SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar return cntx->transaction->ScheduleSingleHop(std::move(cb)); } +// emission_interval_ms assumed to be positive +// limit is assumed to be positive OpResult> OpThrottle(const OpArgs& op_args, const string_view key, - const uint64_t limit, const int64_t emission_interval_ms, + const int64_t limit, const int64_t emission_interval_ms, const uint64_t quantity) { auto& db_slice = op_args.shard->db_slice(); - const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; + if (emission_interval_ms > INT64_MAX / limit) { + return OpStatus::INVALID_INT; + } + const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; // should be positive int64_t remaining = 0; int64_t reset_after_ms = -1000; int64_t retry_after_ms = -1000; - const int64_t increment_ms = emission_interval_ms * quantity; + if (quantity != 0 && static_cast(emission_interval_ms) > INT64_MAX / quantity) { + return OpStatus::INVALID_INT; + } + const int64_t increment_ms = emission_interval_ms * quantity; // should be nonnegative auto [it, e_it] = db_slice.FindExt(op_args.db_cntx, key); const int64_t now_ms = op_args.db_cntx.time_now_ms; @@ -387,21 +396,54 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view } tat_ms = *opt_prev; } - const int64_t new_tat_ms = max(tat_ms, now_ms) + increment_ms; + int64_t new_tat_ms = max(tat_ms, now_ms); + if (new_tat_ms > INT64_MAX - increment_ms) { + return OpStatus::INVALID_INT; + } + new_tat_ms += increment_ms; + + if (new_tat_ms < INT64_MIN + delay_variation_tolerance_ms) { + return OpStatus::INVALID_INT; + } const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms; + + if (allow_at_ms >= 0 ? now_ms < INT64_MIN + allow_at_ms : now_ms > INT64_MAX + allow_at_ms) { + return OpStatus::INVALID_INT; + } const int64_t diff_ms = now_ms - allow_at_ms; const bool limited = diff_ms < 0; int64_t ttl_ms; if (limited) { if (increment_ms <= delay_variation_tolerance_ms) { + if (diff_ms == INT64_MIN) { + return OpStatus::INVALID_INT; + } retry_after_ms = -diff_ms; } + + if (now_ms >= 0 ? tat_ms < INT64_MIN + now_ms : tat_ms > INT64_MAX + now_ms) { + return OpStatus::INVALID_INT; + } ttl_ms = tat_ms - now_ms; } else { + if (now_ms >= 0 ? new_tat_ms < INT64_MIN + now_ms : new_tat_ms > INT64_MAX + now_ms) { + return OpStatus::INVALID_INT; + } ttl_ms = new_tat_ms - now_ms; + } + + if (ttl_ms < delay_variation_tolerance_ms - INT64_MAX) { + return OpStatus::INVALID_INT; + } + const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms; + if (next_ms > -emission_interval_ms) { + remaining = next_ms / emission_interval_ms; + } + reset_after_ms = ttl_ms; + if (!limited) { if (IsValid(it)) { if (IsValid(e_it)) { e_it->second = db_slice.FromAbsoluteTime(new_tat_ms); @@ -425,23 +467,7 @@ OpResult> OpThrottle(const OpArgs& op_args, const string_view } } - const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms; - if (next_ms > -emission_interval_ms) { - remaining = next_ms / emission_interval_ms; - } - reset_after_ms = ttl_ms; - - int64_t retry_after_s = retry_after_ms / 1000; - if (retry_after_ms > 0) { - retry_after_s += 1; - } - - int64_t reset_after_s = reset_after_ms / 1000; - if (reset_after_ms > 0) { - reset_after_s += 1; - } - - return array{limited ? 1 : 0, limit, remaining, retry_after_s, reset_after_s}; + return array{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms}; } } // namespace @@ -1293,7 +1319,14 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { } } - const uint64_t limit = max_burst + 1; + if (max_burst > INT64_MAX - 1) { + return (*cntx)->SendError(kInvalidIntErr); + } + const int64_t limit = max_burst + 1; + + if (period > UINT64_MAX / 1000 || count == 0 || period * 1000 / count > INT64_MAX) { + return (*cntx)->SendError(kInvalidIntErr); + } const int64_t emission_interval_ms = period * 1000 / count; if (emission_interval_ms == 0) { @@ -1309,7 +1342,20 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { if (result) { (*cntx)->StartArray(result->size()); - const auto& array = result.value(); + auto& array = result.value(); + + int64_t retry_after_s = array[3] / 1000; + if (array[3] > 0) { + retry_after_s += 1; + } + array[3] = retry_after_s; + + int64_t reset_after_s = array[4] / 1000; + if (array[4] > 0) { + reset_after_s += 1; + } + array[4] = reset_after_s; + for (const auto& v : array) { (*cntx)->SendLong(v); } @@ -1318,6 +1364,7 @@ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { case OpStatus::WRONG_TYPE: (*cntx)->SendError(kWrongTypeErr); break; + case OpStatus::INVALID_INT: case OpStatus::INVALID_VALUE: (*cntx)->SendError(kInvalidIntErr); break; diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index 19c4638b0671..f1a35145cc5c 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -566,76 +566,98 @@ TEST_F(StringFamilyTest, ClThrottle) { // You can never make a request larger than the maximum. auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0))); // Rate limit normal requests appropriately. resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21))); resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51))); resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51))); AdvanceTime(30000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); AdvanceTime(1000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40))); AdvanceTime(9000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); AdvanceTime(40000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); AdvanceTime(15000); resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); // Zero-volume request just peeks at the state. resp = Run({"cl.throttle", key, max_burst, count, period, "0"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); // High-volume request uses up more of the limit. resp = Run({"cl.throttle", key, max_burst, count, period, "2"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); // Large requests cannot exceed limits resp = Run({"cl.throttle", key, max_burst, count, period, "5"}); + ASSERT_EQ(RespExpr::ARRAY, resp.type); ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); // Zero rates aren't supported - EXPECT_THAT(Run({"cl.throttle", "bar", "10", "1", "0"}), ErrArg("zero rates are not supported")); + resp = Run({"cl.throttle", "bar", "10", "1", "0"}); + ASSERT_EQ(RespExpr::ERROR, resp.type); + EXPECT_THAT(resp, ErrArg("zero rates are not supported")); + + // count == 0 + resp = Run({"cl.throttle", "bar", "10", "0", "1"}); + ASSERT_EQ(RespExpr::ERROR, resp.type); + EXPECT_THAT(resp, ErrArg(kInvalidIntErr)); } } // namespace dfly From 290b71a1473efebb47aa2c38abc36215a71877a9 Mon Sep 17 00:00:00 2001 From: Daria Sukhonina Date: Tue, 24 Jan 2023 14:58:31 +0300 Subject: [PATCH 9/9] add a comment about cl.throttle response --- src/server/string_family.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 06baba283291..4f6f3ba38f5b 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1285,6 +1285,18 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction } /* CL.THROTTLE [] */ +/* Response is array of 5 integers. The meaning of each array item is: + * 1. Whether the action was limited: + * - 0 indicates the action is allowed. + * - 1 indicates that the action was limited/blocked. + * 2. The total limit of the key (max_burst + 1). This is equivalent to the common + * X-RateLimit-Limit HTTP header. + * 3. The remaining limit of the key. Equivalent to X-RateLimit-Remaining. + * 4. The number of seconds until the user should retry, and always -1 if the action was allowed. + * Equivalent to Retry-After. + * 5. The number of seconds until the limit will reset to its maximum capacity. Equivalent to + * X-RateLimit-Reset. + */ void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { const string_view key = ArgS(args, 1);