diff --git a/src/server/family_utils.h b/src/server/family_utils.h index 2fdd9920a703..248d3fb47f52 100644 --- a/src/server/family_utils.h +++ b/src/server/family_utils.h @@ -9,11 +9,33 @@ #include +#include "facade/facade_types.h" + extern "C" { #include "redis/sds.h" } namespace dfly { +template +static std::vector ExpireElements(DenseSet* owner, const facade::CmdArgList values, + uint32_t ttl_sec) { + std::vector res; + res.reserve(values.size()); + + for (size_t i = 0; i < values.size(); i++) { + std::string_view field = facade::ToSV(values[i]); + auto it = owner->Find(field); + if (it != owner->end()) { + it.SetExpiryTime(ttl_sec); + res.emplace_back(ttl_sec == 0 ? 0 : 1); + } else { + res.emplace_back(-2); + } + } + + return res; +} + // Copy str to thread local sds instance. Valid until next WrapSds call on thread sds WrapSds(std::string_view str); diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index a02520bfddc8..1d07c613f8bc 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -7,6 +7,7 @@ #include #include +#include "facade/cmd_arg_parser.h" #include "facade/reply_builder.h" extern "C" { @@ -44,6 +45,7 @@ using namespace facade; namespace { +constexpr uint32_t kMaxTtl = (1UL << 26); constexpr size_t DUMP_FOOTER_SIZE = sizeof(uint64_t) + sizeof(uint16_t); // version number and crc std::optional GetRdbVersion(std::string_view msg) { @@ -672,6 +674,24 @@ OpStatus OpExpire(const OpArgs& op_args, string_view key, const DbSlice::ExpireP return res.status(); } +OpResult> OpFieldExpire(const OpArgs& op_args, string_view key, uint32_t ttl_sec, + CmdArgList values) { + auto& db_slice = op_args.GetDbSlice(); + auto [it, expire_it, auto_updater] = db_slice.FindMutable(op_args.db_cntx, key); + + if (!IsValid(it) || (it->second.ObjType() != OBJ_SET && it->second.ObjType() != OBJ_HASH)) { + std::vector res(values.size(), -2); + return res; + } + + PrimeValue* pv = &it->second; + if (pv->ObjType() == OBJ_SET) { + return SetFamily::SetFieldsExpireTime(op_args, ttl_sec, values, pv); + } else { + return HSetFamily::SetFieldsExpireTime(op_args, ttl_sec, key, values, pv); + } +} + // returns -2 if the key was not found, -3 if the field was not found, // -1 if ttl on the field was not found. OpResult OpFieldTtl(Transaction* t, EngineShard* shard, string_view key, string_view field) { @@ -1261,6 +1281,33 @@ void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) { } } +void GenericFamily::FieldExpire(CmdArgList args, ConnectionContext* cntx) { + CmdArgParser parser{args}; + string_view key = parser.Next(); + string_view ttl_str = parser.Next(); + uint32_t ttl_sec; + if (!absl::SimpleAtoi(ttl_str, &ttl_sec) || ttl_sec == 0 || ttl_sec > kMaxTtl) { + return cntx->SendError(kInvalidIntErr); + } + CmdArgList fields = parser.Tail(); + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpFieldExpire(t->GetOpArgs(shard), key, ttl_sec, fields); + }; + + OpResult> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cntx->reply_builder()); + if (result) { + rb->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + rb->SendLong(v); + } + } else { + cntx->SendError(result.status()); + } +} + // Returns -2 if key not found, WRONG_TYPE if key is not a set or hash // -1 if the field does not have associated TTL on it, and -3 if field is not found. void GenericFamily::FieldTtl(CmdArgList args, ConnectionContext* cntx) { @@ -1763,6 +1810,7 @@ constexpr uint32_t kMove = KEYSPACE | WRITE | FAST; constexpr uint32_t kRestore = KEYSPACE | WRITE | SLOW | DANGEROUS; constexpr uint32_t kExpireTime = KEYSPACE | READ | FAST; constexpr uint32_t kPExpireTime = KEYSPACE | READ | FAST; +constexpr uint32_t kFieldExpire = WRITE | HASH | SET | FAST; } // namespace acl void GenericFamily::Register(CommandRegistry* registry) { @@ -1788,6 +1836,8 @@ void GenericFamily::Register(CommandRegistry* registry) { PexpireAt) << CI{"PEXPIRE", CO::WRITE | CO::FAST | CO::NO_AUTOJOURNAL, 3, 1, 1, acl::kPExpire}.HFUNC( Pexpire) + << CI{"FIELDEXPIRE", CO::WRITE | CO::FAST | CO::DENYOOM, -4, 1, 1, acl::kFieldExpire}.HFUNC( + FieldExpire) << CI{"RENAME", CO::WRITE | CO::NO_AUTOJOURNAL, 3, 1, 2, acl::kRename}.HFUNC(Rename) << CI{"RENAMENX", CO::WRITE | CO::NO_AUTOJOURNAL, 3, 1, 2, acl::kRenamNX}.HFUNC(RenameNx) << CI{"SELECT", kSelectOpts, 2, 0, 0, acl::kSelect}.HFUNC(Select) diff --git a/src/server/generic_family.h b/src/server/generic_family.h index 4e380850d10c..ad5a86d4e2dd 100644 --- a/src/server/generic_family.h +++ b/src/server/generic_family.h @@ -71,6 +71,7 @@ class GenericFamily { static void Restore(CmdArgList args, ConnectionContext* cntx); static void RandomKey(CmdArgList args, ConnectionContext* cntx); static void FieldTtl(CmdArgList args, ConnectionContext* cntx); + static void FieldExpire(CmdArgList args, ConnectionContext* cntx); static ErrorReply RenameGeneric(CmdArgList args, bool destination_should_not_exist, ConnectionContext* cntx); diff --git a/src/server/generic_family_test.cc b/src/server/generic_family_test.cc index 212c84c80fe6..7e877dc53c41 100644 --- a/src/server/generic_family_test.cc +++ b/src/server/generic_family_test.cc @@ -795,6 +795,38 @@ TEST_F(GenericFamilyTest, JsonType) { ASSERT_THAT(vec, ElementsAre("json")); } +TEST_F(GenericFamilyTest, FieldExpireSet) { + Run({"SADD", "key", "a", "b", "c"}); + EXPECT_THAT(Run({"FIELDEXPIRE", "key", "10", "a", "b", "c"}), + RespArray(ElementsAre(IntArg(1), IntArg(1), IntArg(1)))); + AdvanceTime(10'000); + EXPECT_THAT(Run({"SMEMBERS", "key"}), RespArray(ElementsAre())); +} + +TEST_F(GenericFamilyTest, FieldExpireHset) { + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(CheckedInt({"HSET", "key", absl::StrCat("k", i), "v"}), 1); + } + EXPECT_THAT(Run({"FIELDEXPIRE", "key", "10", "k0", "k1", "k2"}), + RespArray(ElementsAre(IntArg(1), IntArg(1), IntArg(1)))); + AdvanceTime(10'000); + EXPECT_THAT(Run({"HGETALL", "key"}), RespArray(ElementsAre())); +} + +TEST_F(GenericFamilyTest, FieldExpireNoSuchField) { + EXPECT_EQ(CheckedInt({"SADD", "key", "a"}), 1); + EXPECT_EQ(CheckedInt({"HSET", "key2", "k0", "v0"}), 1); + EXPECT_THAT(Run({"FIELDEXPIRE", "key", "10", "a", "b"}), + RespArray(ElementsAre(IntArg(1), IntArg(-2)))); + EXPECT_THAT(Run({"FIELDEXPIRE", "key2", "10", "k0", "b"}), + RespArray(ElementsAre(IntArg(1), IntArg(-2)))); +} + +TEST_F(GenericFamilyTest, FieldExpireNoSuchKey) { + EXPECT_THAT(Run({"FIELDEXPIRE", "key", "10", "a", "b"}), + RespArray(ElementsAre(IntArg(-2), IntArg(-2)))); +} + TEST_F(GenericFamilyTest, ExpireTime) { EXPECT_EQ(-2, CheckedInt({"EXPIRETIME", "foo"})); EXPECT_EQ(-2, CheckedInt({"PEXPIRETIME", "foo"})); diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 0883d4f1cf91..65abbb946141 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -4,6 +4,8 @@ #include "server/hset_family.h" +#include "server/family_utils.h" + extern "C" { #include "redis/listpack.h" #include "redis/redis_aux.h" @@ -725,6 +727,23 @@ void HGetGeneric(CmdArgList args, ConnectionContext* cntx, uint8_t getall_mask) } } +OpResult> OpHExpire(const OpArgs& op_args, string_view key, uint32_t ttl_sec, + CmdArgList values) { + auto& db_slice = op_args.GetDbSlice(); + auto op_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_HASH); + + if (!op_res) { + if (op_res.status() == OpStatus::KEY_NOTFOUND) { + std::vector res(values.size(), -2); + return res; + } + return op_res.status(); + } + + PrimeValue* pv = &((*op_res).it->second); + return HSetFamily::SetFieldsExpireTime(op_args, ttl_sec, key, values, pv); +} + // HSETEX key [NX] tll_sec field value field value ... void HSetEx(CmdArgList args, ConnectionContext* cntx) { CmdArgParser parser{args}; @@ -808,6 +827,49 @@ void HSetFamily::HExists(CmdArgList args, ConnectionContext* cntx) { } } +void HSetFamily::HExpire(CmdArgList args, ConnectionContext* cntx) { + CmdArgParser parser{args}; + string_view key = parser.Next(); + string_view ttl_str = parser.Next(); + uint32_t ttl_sec; + constexpr uint32_t kMaxTtl = (1UL << 26); + if (!absl::SimpleAtoi(ttl_str, &ttl_sec) || ttl_sec == 0 || ttl_sec > kMaxTtl) { + return cntx->SendError(kInvalidIntErr); + } + if (!static_cast(parser.Check("FIELDS"sv))) { + return cntx->SendError("Mandatory argument FIELDS is missing or not at the right position", + kSyntaxErrType); + } + + string_view numFieldsStr = parser.Next(); + uint32_t numFields; + if (!absl::SimpleAtoi(numFieldsStr, &numFields) || numFields == 0) { + return cntx->SendError(kInvalidIntErr); + } + + CmdArgList fields = parser.Tail(); + if (fields.size() != numFields) { + return cntx->SendError("The `numfields` parameter must match the number of arguments", + kSyntaxErrType); + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpHExpire(t->GetOpArgs(shard), key, ttl_sec, fields); + }; + + OpResult> result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + auto* rb = static_cast(cntx->reply_builder()); + if (result) { + rb->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + rb->SendLong(v); + } + } else { + cntx->SendError(result.status()); + } +} + void HSetFamily::HMGet(CmdArgList args, ConnectionContext* cntx) { string_view key = ArgS(args, 0); @@ -1189,6 +1251,7 @@ constexpr uint32_t kHSet = WRITE | HASH | FAST; constexpr uint32_t kHSetEx = WRITE | HASH | FAST; constexpr uint32_t kHSetNx = WRITE | HASH | FAST; constexpr uint32_t kHStrLen = READ | HASH | FAST; +constexpr uint32_t kHExpire = WRITE | HASH | FAST; constexpr uint32_t kHVals = READ | HASH | SLOW; } // namespace acl @@ -1206,6 +1269,7 @@ void HSetFamily::Register(CommandRegistry* registry) { << CI{"HINCRBYFLOAT", CO::WRITE | CO::DENYOOM | CO::FAST, 4, 1, 1, acl::kHIncrByFloat}.HFUNC( HIncrByFloat) << CI{"HKEYS", CO::READONLY, 2, 1, 1, acl::kHKeys}.HFUNC(HKeys) + << CI{"HEXPIRE", CO::WRITE | CO::FAST | CO::DENYOOM, -5, 1, 1, acl::kHExpire}.HFUNC(HExpire) << CI{"HRANDFIELD", CO::READONLY, -2, 1, 1, acl::kHRandField}.HFUNC(HRandField) << CI{"HSCAN", CO::READONLY, -3, 1, 1, acl::kHScan}.HFUNC(HScan) << CI{"HSET", CO::WRITE | CO::FAST | CO::DENYOOM, -4, 1, 1, acl::kHSet}.HFUNC(HSet) @@ -1276,4 +1340,27 @@ int32_t HSetFamily::FieldExpireTime(const DbContext& db_context, const PrimeValu } } +vector HSetFamily::SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_sec, + string_view key, CmdArgList values, PrimeValue* pv) { + DCHECK_EQ(OBJ_HASH, pv->ObjType()); + op_args.shard->search_indices()->RemoveDoc(key, op_args.db_cntx, *pv); + + if (pv->Encoding() == kEncodingListPack) { + // a valid result can never be a listpack, since it doesnt keep ttl + uint8_t* lp = (uint8_t*)pv->RObjPtr(); + auto& db_slice = op_args.GetDbSlice(); + DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index); + stats->listpack_bytes -= lpBytes(lp); + stats->listpack_blob_cnt--; + StringMap* sm = HSetFamily::ConvertToStrMap(lp); + pv->InitRobj(OBJ_HASH, kEncodingStrMap2, sm); + } + + // This needs to be explicitly fetched again since the pv might have changed. + StringMap* sm = container_utils::GetStringMap(*pv, op_args.db_cntx); + vector res = ExpireElements(sm, values, ttl_sec); + op_args.shard->search_indices()->AddDoc(key, op_args.db_cntx, *pv); + return res; +} + } // namespace dfly diff --git a/src/server/hset_family.h b/src/server/hset_family.h index 47969b05ae60..296051171607 100644 --- a/src/server/hset_family.h +++ b/src/server/hset_family.h @@ -29,9 +29,14 @@ class HSetFamily { static int32_t FieldExpireTime(const DbContext& db_context, const PrimeValue& pv, std::string_view field); + static std::vector SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_sec, + std::string_view key, CmdArgList values, + PrimeValue* pv); + private: // TODO: to move it to anonymous namespace in cc file. + static void HExpire(CmdArgList args, ConnectionContext* cntx); static void HDel(CmdArgList args, ConnectionContext* cntx); static void HLen(CmdArgList args, ConnectionContext* cntx); static void HExists(CmdArgList args, ConnectionContext* cntx); diff --git a/src/server/hset_family_test.cc b/src/server/hset_family_test.cc index 088c3e4a6a50..0c8bc031cf2a 100644 --- a/src/server/hset_family_test.cc +++ b/src/server/hset_family_test.cc @@ -119,6 +119,12 @@ TEST_F(HSetFamilyTest, HIncr) { EXPECT_THAT(resp, ErrArg("hash value is not an integer")); } +TEST_F(HSetFamilyTest, HIncrRespected) { + Run({"hset", "key", "a", "1"}); + EXPECT_EQ(11, CheckedInt({"hincrby", "key", "a", "10"})); + EXPECT_EQ(11, CheckedInt({"hget", "key", "a"})); +} + TEST_F(HSetFamilyTest, HScan) { for (int i = 0; i < 10; i++) { Run({"HSET", "myhash", absl::StrCat("Field-", i), absl::StrCat("Value-", i)}); @@ -383,6 +389,44 @@ TEST_F(HSetFamilyTest, Issue2102) { EXPECT_THAT(Run({"HGETALL", "key"}), RespArray(ElementsAre())); } +TEST_F(HSetFamilyTest, HExpire) { + EXPECT_EQ(CheckedInt({"HSET", "key", "k0", "v0", "k1", "v1", "k2", "v2"}), 3); + EXPECT_THAT(Run({"HEXPIRE", "key", "10", "FIELDS", "3", "k0", "k1", "k2"}), + RespArray(ElementsAre(IntArg(1), IntArg(1), IntArg(1)))); + AdvanceTime(10'000); + EXPECT_THAT(Run({"HGETALL", "key"}), RespArray(ElementsAre())); + + EXPECT_EQ(CheckedInt({"HSETEX", "key2", "60", "k0", "v0", "k1", "v2"}), 2); + EXPECT_THAT(Run({"HEXPIRE", "key2", "10", "FIELDS", "2", "k0", "k1"}), + RespArray(ElementsAre(IntArg(1), IntArg(1)))); + AdvanceTime(10'000); + EXPECT_THAT(Run({"HGETALL", "key2"}), RespArray(ElementsAre())); +} + +TEST_F(HSetFamilyTest, HExpireNoExpireEarly) { + EXPECT_EQ(CheckedInt({"HSET", "key", "k0", "v0", "k1", "v1"}), 2); + EXPECT_THAT(Run({"HEXPIRE", "key", "10", "FIELDS", "2", "k0", "k1"}), + RespArray(ElementsAre(IntArg(1), IntArg(1)))); + AdvanceTime(9'000); + EXPECT_THAT(Run({"HGETALL", "key"}), RespArray(UnorderedElementsAre("k0", "v0", "k1", "v1"))); +} + +TEST_F(HSetFamilyTest, HExpireNoSuchField) { + EXPECT_EQ(CheckedInt({"HSET", "key", "k0", "v0"}), 1); + EXPECT_THAT(Run({"HEXPIRE", "key", "10", "FIELDS", "2", "k0", "k1"}), + RespArray(ElementsAre(IntArg(1), IntArg(-2)))); +} + +TEST_F(HSetFamilyTest, HExpireNoSuchKey) { + EXPECT_THAT(Run({"HEXPIRE", "key", "10", "FIELDS", "2", "k0", "k1"}), + RespArray(ElementsAre(IntArg(-2), IntArg(-2)))); +} + +TEST_F(HSetFamilyTest, HExpireNoAddNew) { + Run({"HEXPIRE", "key", "10", "FIELDS", "1", "k0"}); + EXPECT_THAT(Run({"HGETALL", "key"}), RespArray(ElementsAre())); +} + TEST_F(HSetFamilyTest, RandomFieldAllExpired) { for (int i = 0; i < 10; ++i) { EXPECT_EQ(CheckedInt({"HSETEX", "key", "10", absl::StrCat("k", i), "v"}), 1); diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 77db27654d89..c914bf7c9a6f 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -1529,4 +1529,22 @@ int32_t SetFamily::FieldExpireTime(const DbContext& db_context, const PrimeValue return GetExpiry(db_context, st, field); } +vector SetFamily::SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_sec, + CmdArgList values, PrimeValue* pv) { + DCHECK_EQ(OBJ_SET, pv->ObjType()); + + if (pv->Encoding() == kEncodingIntSet) { + // a valid result can never be a intset, since it doesnt keep ttl + intset* is = (intset*)pv->RObjPtr(); + StringSet* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); + if (!ss) { + std::vector out(values.size(), -2); + return out; + } + pv->InitRobj(OBJ_SET, kEncodingStrMap2, ss); + } + + return ExpireElements((StringSet*)pv->RObjPtr(), values, ttl_sec); +} + } // namespace dfly diff --git a/src/server/set_family.h b/src/server/set_family.h index 94c565bc5acd..2d9b056a5e2f 100644 --- a/src/server/set_family.h +++ b/src/server/set_family.h @@ -33,6 +33,9 @@ class SetFamily { static int32_t FieldExpireTime(const DbContext& db_context, const PrimeValue& pv, std::string_view field); + static std::vector SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_sec, + CmdArgList values, PrimeValue* pv); + private: };