Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(string family): implement cl.throttle #714

Merged
merged 10 commits into from
Jan 24, 2023
160 changes: 159 additions & 1 deletion src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ extern "C" {
#include <absl/container/inlined_vector.h>
#include <double-conversion/string-to-double.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <tuple>
romange marked this conversation as resolved.
Show resolved Hide resolved

#include "base/logging.h"
#include "redis/util.h"
Expand Down Expand Up @@ -356,6 +359,91 @@ OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar
return cntx->transaction->ScheduleSingleHop(std::move(cb));
}

OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view key,
const uint64_t limit, const int64_t emission_interval_ms,
const uint64_t quantity) {
auto& db_slice = op_args.shard->db_slice();

const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit;

int64_t remaining = 0;
int64_t reset_after_ms = -1000;
int64_t retry_after_ms = -1000;

const int64_t increment_ms = emission_interval_ms * quantity;

auto [it, e_it] = db_slice.FindExt(op_args.db_cntx, key);
const int64_t now_ms = op_args.db_cntx.time_now_ms;

int64_t tat_ms = now_ms;
if (IsValid(it)) {
if (it->second.ObjType() != OBJ_STRING) {
return OpStatus::WRONG_TYPE;
}

auto opt_prev = it->second.TryGetInt();
if (!opt_prev) {
return OpStatus::INVALID_VALUE;
}
tat_ms = *opt_prev;
}
const int64_t new_tat_ms = max(tat_ms, now_ms) + increment_ms;

const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms;
const int64_t diff_ms = now_ms - allow_at_ms;

const bool limited = diff_ms < 0;
int64_t ttl_ms;
if (limited) {
if (increment_ms <= delay_variation_tolerance_ms) {
retry_after_ms = -diff_ms;
}
ttl_ms = tat_ms - now_ms;
} else {
ttl_ms = new_tat_ms - now_ms;

if (IsValid(it)) {
if (IsValid(e_it)) {
e_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
} else {
db_slice.AddExpire(op_args.db_cntx.db_index, it, new_tat_ms);
}

db_slice.PreUpdate(op_args.db_cntx.db_index, it);
it->second.SetInt(new_tat_ms);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
} else {
CompactObj cobj;
cobj.SetInt(new_tat_ms);

// AddNew calls PostUpdate inside.
try {
it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), new_tat_ms);
} catch (bad_alloc&) {
return OpStatus::OUT_OF_MEMORY;
}
}
}

const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms;
if (next_ms > -emission_interval_ms) {
remaining = next_ms / emission_interval_ms;
}
reset_after_ms = ttl_ms;

int64_t retry_after_s = retry_after_ms / 1000;
romange marked this conversation as resolved.
Show resolved Hide resolved
if (retry_after_ms > 0) {
retry_after_s += 1;
}

int64_t reset_after_s = reset_after_ms / 1000;
if (reset_after_ms > 0) {
reset_after_s += 1;
}

return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_s, reset_after_s};
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace

OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) {
Expand Down Expand Up @@ -1170,6 +1258,75 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction
return response;
}

/* CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>] */
void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
const string_view key = ArgS(args, 1);

// Allow max burst in number of tokens
uint64_t max_burst;
const string_view max_burst_str = ArgS(args, 2);
if (!absl::SimpleAtoi(max_burst_str, &max_burst)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Emit count of tokens per period
uint64_t count;
const string_view count_str = ArgS(args, 3);
if (!absl::SimpleAtoi(count_str, &count)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Period of emitting count of tokens
uint64_t period;
const string_view period_str = ArgS(args, 4);
if (!absl::SimpleAtoi(period_str, &period)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Apply quantity of tokens now
uint64_t quantity = 1;
if (args.size() > 5) {
const string_view quantity_str = ArgS(args, 5);

if (!absl::SimpleAtoi(quantity_str, &quantity)) {
return (*cntx)->SendError(kInvalidIntErr);
}
}

const uint64_t limit = max_burst + 1;
const int64_t emission_interval_ms = period * 1000 / count;
romange marked this conversation as resolved.
Show resolved Hide resolved

if (emission_interval_ms == 0) {
return (*cntx)->SendError("Zero rates are not supported");
}

auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<array<int64_t, 5>> {
return OpThrottle(t->GetOpArgs(shard), key, limit, emission_interval_ms, quantity);
};

Transaction* trans = cntx->transaction;
OpResult<array<int64_t, 5>> result = trans->ScheduleSingleHopT(std::move(cb));

switch (result.status()) {
case OpStatus::WRONG_TYPE:
(*cntx)->SendError(kWrongTypeErr);
break;
case OpStatus::INVALID_VALUE:
(*cntx)->SendError(kInvalidIntErr);
break;
case OpStatus::OUT_OF_MEMORY:
(*cntx)->SendError(kOutOfMemory);
break;
default:
(*cntx)->StartArray(result->size());
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
const auto& array = result.value();
for (const auto& v : array) {
(*cntx)->SendLong(v);
}
break;
}
}

void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp);
get_qps.Init(pp);
Expand Down Expand Up @@ -1206,7 +1363,8 @@ void StringFamily::Register(CommandRegistry* registry) {
<< CI{"STRLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(StrLen)
<< CI{"GETRANGE", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange)
<< CI{"SUBSTR", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange) // Alias for GetRange
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange);
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange)
<< CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1}.HFUNC(ClThrottle);
}

} // namespace dfly
2 changes: 2 additions & 0 deletions src/server/string_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class StringFamily {
static void Prepend(CmdArgList args, ConnectionContext* cntx);
static void PSetEx(CmdArgList args, ConnectionContext* cntx);

static void ClThrottle(CmdArgList args, ConnectionContext* cntx);

// These functions are used internally, they do not implement any specific command
static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx);
static void ExtendGeneric(CmdArgList args, bool prepend, ConnectionContext* cntx);
Expand Down
63 changes: 63 additions & 0 deletions src/server/string_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,67 @@ TEST_F(StringFamilyTest, GetEx) {
EXPECT_THAT(Run({"getex", "foo"}), ArgType(RespExpr::NIL));
}

TEST_F(StringFamilyTest, ClThrottle) {
const int64_t limit = 5;
const char* const key = "foo";
const char* const max_burst = "4"; // limit - 1
const char* const count = "1";
const char* const period = "10";

// You can never make a request larger than the maximum.
auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0)));

// Rate limit normal requests appropriately.
resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51)));

AdvanceTime(30000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

AdvanceTime(1000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40)));

AdvanceTime(9000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

AdvanceTime(40000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

AdvanceTime(15000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// Zero-volume request just peeks at the state.
resp = Run({"cl.throttle", key, max_burst, count, period, "0"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// High-volume request uses up more of the limit.
resp = Run({"cl.throttle", key, max_burst, count, period, "2"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

// Large requests cannot exceed limits
resp = Run({"cl.throttle", key, max_burst, count, period, "5"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31)));
}

} // namespace dfly