Skip to content

Commit

Permalink
fix(generic_family): fix RenameGeneric command for non-string data types
Browse files Browse the repository at this point in the history
fixes dragonflydb#3107, fixes dragonflydb#3113, fixes dragonflydb#307

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan committed Jun 18, 2024
1 parent e45c1e9 commit 061ff1a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 115 deletions.
247 changes: 137 additions & 110 deletions src/server/generic_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

#include <optional>

#include "absl/types/span.h"
#include "facade/reply_builder.h"
#include "glog/logging.h"
#include "server/table.h"

extern "C" {
#include "redis/crc64.h"
Expand Down Expand Up @@ -145,13 +148,18 @@ bool RdbRestoreValue::Add(std::string_view data, std::string_view key, DbSlice&
return res.ok();
}

class Renamer;

class RestoreArgs {
private:
static constexpr int64_t NO_EXPIRATION = 0;

int64_t expiration_ = NO_EXPIRATION;
bool abs_time_ = false;
bool replace_ = false; // if true, over-ride existing key

friend class Renamer;

public:
constexpr bool Replace() const {
return replace_;
Expand Down Expand Up @@ -243,161 +251,182 @@ OpResult<RestoreArgs> RestoreArgs::TryFrom(const CmdArgList& args) {

OpStatus OpPersist(const OpArgs& op_args, string_view key);

OpResult<std::string> OpDump(const OpArgs& op_args, string_view key);

OpResult<bool> OnRestore(const OpArgs& op_args, std::string_view key, std::string_view payload,
RestoreArgs restore_args, int rdb_version);

class Renamer {
public:
Renamer(ShardId source_id) : src_sid_(source_id) {
Renamer(std::string_view src_key, std::string_view dest_key, unsigned shard_count)
: src_key_(src_key),
dest_key_(dest_key),
src_sid_(Shard(src_key, shard_count)),
dest_sid_(Shard(dest_key, shard_count)) {
}

void Find(Transaction* t);
void Initialize(Transaction* t);

OpResult<void> status() const {
return status_;
};

void Finalize(Transaction* t, bool skip_exist_dest);
void Finalize(Transaction* t, bool destination_should_not_exist);

private:
OpStatus MoveSrc(Transaction* t, EngineShard* es);
OpStatus UpdateDest(Transaction* t, EngineShard* es);
bool KeyExists(Transaction* t, EngineShard* shard, std::string_view key);
OpStatus SerializeSrc(Transaction* t, EngineShard* shard);

ShardId src_sid_;
OpStatus DelSrc(Transaction* t, EngineShard* shard);
OpStatus DeserializeDest(Transaction* t, EngineShard* shard);

struct FindResult {
string_view key;
PrimeValue ref_val;
uint64_t expire_ts;
struct SerializedValue {
std::string value;
int64_t expire_ts;
bool sticky;
bool found = false;
};

PrimeValue pv_;
string str_val_;
private:
const std::string_view src_key_;
const std::string_view dest_key_;
const ShardId src_sid_;
const ShardId dest_sid_;

bool src_found_ = false;
bool dest_found_ = false;

SerializedValue serialized_value_;

FindResult src_res_, dest_res_; // index 0 for source, 1 for destination
OpResult<void> status_;
};

void Renamer::Find(Transaction* t) {
void Renamer::Initialize(Transaction* t) {
auto cb = [this](Transaction* t, EngineShard* shard) {
auto args = t->GetShardArgs(shard->shard_id());
DCHECK_EQ(1u, args.Size());

FindResult* res = (shard->shard_id() == src_sid_) ? &src_res_ : &dest_res_;
const ShardId shard_id = shard->shard_id();

res->key = args.Front();
auto& db_slice = EngineShard::tlocal()->db_slice();
auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), res->key);
if (shard_id == src_sid_) {
src_found_ = KeyExists(t, shard, src_key_);
if (src_found_) {
return SerializeSrc(t, shard);
}
}

res->found = IsValid(it);
if (res->found) {
res->ref_val = it->second.AsRef();
res->expire_ts = db_slice.ExpireTime(exp_it);
res->sticky = it->first.IsSticky();
if (shard_id == dest_sid_) {
dest_found_ = KeyExists(t, shard, dest_key_);
}

return OpStatus::OK;
};

t->Execute(std::move(cb), false);
};

void Renamer::Finalize(Transaction* t, bool skip_exist_dest) {
if (!src_res_.found) {
void Renamer::Finalize(Transaction* t, bool destination_should_not_exist) {
if (!src_found_) {
status_ = OpStatus::KEY_NOTFOUND;
t->Conclude();
return;
}

if (dest_res_.found && skip_exist_dest) {
if (dest_found_ && destination_should_not_exist) {
status_ = OpStatus::KEY_EXISTS;
t->Conclude();
return;
}

DCHECK(src_res_.ref_val.IsRef());

// Src key exist and we need to override the destination.
// Alternatively, we could apply an optimistic algorithm and move src at Find step.
// We would need to restore the state in case of cleanups.
t->Execute([&](Transaction* t, EngineShard* shard) { return MoveSrc(t, shard); }, false);
t->Execute([&](Transaction* t, EngineShard* shard) { return UpdateDest(t, shard); }, true);
}
auto cb = [this](Transaction* t, EngineShard* shard) {
const ShardId shard_id = shard->shard_id();

OpStatus Renamer::MoveSrc(Transaction* t, EngineShard* es) {
if (es->shard_id() == src_sid_) { // Handle source key.
auto res = es->db_slice().FindMutable(t->GetDbContext(), src_res_.key);
auto& it = res.it;
CHECK(IsValid(it));

// We distinguish because of the SmallString that is pinned to its thread by design,
// thus can not be accessed via another thread.
// Therefore, we copy it to standard string in its thread.
if (it->second.ObjType() == OBJ_STRING) {
it->second.GetString(&str_val_);
} else {
bool has_expire = it->second.HasExpire();
pv_ = std::move(it->second);
it->second.SetExpire(has_expire);
if (shard_id == src_sid_) {
return DelSrc(t, shard);
}

res.post_updater.Run();
CHECK(es->db_slice().Del(t->GetDbIndex(), it)); // delete the entry with empty value in it.
if (es->journal()) {
RecordJournal(t->GetOpArgs(es), "DEL", ArgSlice{src_res_.key}, 2);
if (shard_id == dest_sid_) {
return DeserializeDest(t, shard);
}

return OpStatus::OK;
};

t->Execute(std::move(cb), true);
}

bool Renamer::KeyExists(Transaction* t, EngineShard* shard, std::string_view key) {
auto& db_slice = shard->db_slice();
auto it = db_slice.FindReadOnly(t->GetDbContext(), key).it;
return IsValid(it);
}

OpStatus Renamer::SerializeSrc(Transaction* t, EngineShard* shard) {
auto dump_res = OpDump(t->GetOpArgs(shard), src_key_);

RETURN_ON_BAD_STATUS(dump_res);

auto& db_slice = shard->db_slice();
auto [it, exp_it] = db_slice.FindReadOnly(t->GetDbContext(), src_key_);

serialized_value_.value = std::move(dump_res.value());
serialized_value_.expire_ts = db_slice.ExpireTime(exp_it);
serialized_value_.sticky = it->first.IsSticky();

return OpStatus::OK;
}

OpStatus Renamer::DelSrc(Transaction* t, EngineShard* shard) {
auto res = shard->db_slice().FindMutable(t->GetDbContext(), src_key_);
auto& it = res.it;

CHECK(IsValid(it));

res.post_updater.Run();
CHECK(shard->db_slice().Del(t->GetDbIndex(), it));
if (shard->journal()) {
RecordJournal(t->GetOpArgs(shard), "DEL", ArgSlice{src_key_}, 2);
}

return OpStatus::OK;
}

OpStatus Renamer::UpdateDest(Transaction* t, EngineShard* es) {
if (es->shard_id() != src_sid_) {
auto& db_slice = es->db_slice();
string_view dest_key = dest_res_.key;
auto res = db_slice.FindMutable(t->GetDbContext(), dest_key);
auto& dest_it = res.it;
bool is_prior_list = false;

if (IsValid(dest_it)) {
bool has_expire = dest_it->second.HasExpire();
is_prior_list = dest_it->second.ObjType() == OBJ_LIST;

if (src_res_.ref_val.ObjType() == OBJ_STRING) {
dest_it->second.SetString(str_val_);
} else {
dest_it->second = std::move(pv_);
}
dest_it->second.SetExpire(has_expire); // preserve expire flag.
db_slice.UpdateExpire(t->GetDbIndex(), dest_it, src_res_.expire_ts);
} else {
if (src_res_.ref_val.ObjType() == OBJ_STRING) {
pv_.SetString(str_val_);
}
auto op_res =
db_slice.AddNew(t->GetDbContext(), dest_key, std::move(pv_), src_res_.expire_ts);
RETURN_ON_BAD_STATUS(op_res);
res = std::move(*op_res);
}
OpStatus Renamer::DeserializeDest(Transaction* t, EngineShard* shard) {
auto& db_slice = shard->db_slice();
auto res = db_slice.FindMutable(t->GetDbContext(), dest_key_);

dest_it->first.SetSticky(src_res_.sticky);
auto& dest_it = res.it;
const bool is_prior_list = IsValid(dest_it) && dest_it->second.ObjType() == OBJ_LIST;

if (!is_prior_list && dest_it->second.ObjType() == OBJ_LIST && es->blocking_controller()) {
es->blocking_controller()->AwakeWatched(t->GetDbIndex(), dest_key);
}
if (es->journal()) {
OpArgs op_args = t->GetOpArgs(es);
string scratch;
// todo insert under multi exec
RecordJournal(op_args, "SET"sv, ArgSlice{dest_key, dest_it->second.GetSlice(&scratch)}, 2,
true);
if (dest_it->first.IsSticky()) {
RecordJournal(op_args, "STICK"sv, ArgSlice{dest_key}, 2, true);
}
if (dest_it->second.HasExpire()) {
auto time = absl::StrCat(src_res_.expire_ts);
RecordJournal(op_args, "PEXPIREAT"sv, ArgSlice{dest_key, time}, 2, true);
}
RecordJournalFinish(op_args, 2);
int rdb_version = 0;
CHECK(VerifyFooter(serialized_value_.value, &rdb_version));

RestoreArgs restore_args;
restore_args.expiration_ = serialized_value_.expire_ts;
restore_args.abs_time_ = true;
restore_args.replace_ = true;

auto restore_res =
OnRestore(t->GetOpArgs(shard), dest_key_, serialized_value_.value, restore_args, rdb_version);
RETURN_ON_BAD_STATUS(restore_res);

dest_it = db_slice.FindMutable(t->GetDbContext(), dest_key_).it;
dest_it->first.SetSticky(serialized_value_.sticky);

if (!is_prior_list && dest_it->second.ObjType() == OBJ_LIST && shard->blocking_controller()) {
shard->blocking_controller()->AwakeWatched(t->GetDbIndex(), dest_key_);
}

if (shard->journal()) {
OpArgs op_args = t->GetOpArgs(shard);

auto time = absl::StrCat(serialized_value_.expire_ts);
RecordJournal(op_args, "RESTORE"sv,
ArgSlice{dest_key_, time, serialized_value_.value, "REPLACE"sv, "ABSTTL"sv}, 2,
true);
if (dest_it->first.IsSticky()) {
RecordJournal(op_args, "STICK"sv, ArgSlice{dest_key_}, 2, true);
}
RecordJournalFinish(op_args, 2);
}

return OpStatus::OK;
Expand Down Expand Up @@ -1332,7 +1361,7 @@ void GenericFamily::Time(CmdArgList args, ConnectionContext* cntx) {
rb->SendLong(now_usec % 1000000);
}

OpResult<void> GenericFamily::RenameGeneric(CmdArgList args, bool skip_exist_dest,
OpResult<void> GenericFamily::RenameGeneric(CmdArgList args, bool destination_should_not_exist,
ConnectionContext* cntx) {
string_view key[2] = {ArgS(args, 0), ArgS(args, 1)};

Expand All @@ -1341,21 +1370,19 @@ OpResult<void> GenericFamily::RenameGeneric(CmdArgList args, bool skip_exist_des
if (transaction->GetUniqueShardCnt() == 1) {
transaction->ReviveAutoJournal(); // Safe to use RENAME with single shard
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRen(t->GetOpArgs(shard), key[0], key[1], skip_exist_dest);
return OpRen(t->GetOpArgs(shard), key[0], key[1], destination_should_not_exist);
};
OpResult<void> result = transaction->ScheduleSingleHopT(std::move(cb));

return result;
}

unsigned shard_count = shard_set->size();
Renamer renamer{Shard(key[0], shard_count)};

// Phase 1 -> Fetch keys from both shards.
// Phase 2 -> If everything is ok, clone the source object, delete the destination object, and
// set its ptr to cloned one. we also copy the expiration data of the source key.
renamer.Find(transaction);
renamer.Finalize(transaction, skip_exist_dest);
Renamer renamer{key[0], key[1], shard_count};

renamer.Initialize(transaction);
renamer.Finalize(transaction, destination_should_not_exist);

return renamer.status();
}
Expand Down Expand Up @@ -1422,7 +1449,7 @@ OpResult<uint32_t> GenericFamily::OpExists(const OpArgs& op_args, const ShardArg
}

OpResult<void> GenericFamily::OpRen(const OpArgs& op_args, string_view from_key, string_view to_key,
bool skip_exists) {
bool destination_should_not_exist) {
auto* es = op_args.shard;
auto& db_slice = es->db_slice();
auto from_res = db_slice.FindMutable(op_args.db_cntx, from_key);
Expand All @@ -1435,7 +1462,7 @@ OpResult<void> GenericFamily::OpRen(const OpArgs& op_args, string_view from_key,
bool is_prior_list = false;
auto to_res = db_slice.FindMutable(op_args.db_cntx, to_key);
if (IsValid(to_res.it)) {
if (skip_exists)
if (destination_should_not_exist)
return OpStatus::KEY_EXISTS;

is_prior_list = (to_res.it->second.ObjType() == OBJ_LIST);
Expand Down
2 changes: 1 addition & 1 deletion src/server/generic_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GenericFamily {
static void RandomKey(CmdArgList args, ConnectionContext* cntx);
static void FieldTtl(CmdArgList args, ConnectionContext* cntx);

static OpResult<void> RenameGeneric(CmdArgList args, bool skip_exist_dest,
static OpResult<void> RenameGeneric(CmdArgList args, bool destination_should_not_exist,
ConnectionContext* cntx);
static void TtlGeneric(CmdArgList args, ConnectionContext* cntx, TimeUnit unit);

Expand Down
8 changes: 4 additions & 4 deletions tests/dragonfly/replication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,16 +689,16 @@ async def check_expire(key):

await c_master.set("renamekey", "1000", px=50000)
await skip_cmd()
# Check RENAME turns into DEL SET and PEXPIREAT
# Check RENAME turns into DEL and RESTORE
await check_list_ooo(
"RENAME renamekey renamed",
[r"DEL renamekey", r"SET renamed 1000", r"PEXPIREAT renamed (.*?)"],
[r"DEL renamekey", r"RESTORE renamed (.*?) (.*?) REPLACE ABSTTL"],
)
await check_expire("renamed")
# Check RENAMENX turns into DEL SET and PEXPIREAT
# Check RENAMENX turns into DEL and RESTORE
await check_list_ooo(
"RENAMENX renamed renamekey",
[r"DEL renamed", r"SET renamekey 1000", r"PEXPIREAT renamekey (.*?)"],
[r"DEL renamed", r"RESTORE renamekey (.*?) (.*?) REPLACE ABSTTL"],
)
await check_expire("renamekey")

Expand Down

0 comments on commit 061ff1a

Please sign in to comment.