From 893c741c14b181cdebaeb53814ce7283effbf076 Mon Sep 17 00:00:00 2001
From: Vladislav
Date: Tue, 22 Nov 2022 19:17:31 +0300
Subject: [PATCH] feat(server): Replication errors & cancellation (#501)

---
 src/server/common.h | 77 +++++++++
 src/server/dflycmd.cc | 244 ++++++++++++++++------
 src/server/dflycmd.h | 106 +++++++++---
 src/server/engine_shard_set.h | 17 +-
 src/server/journal/journal_slice.cc | 14 +-
 src/server/journal/journal_slice.h | 3 +-
 src/server/rdb_save.cc | 45 +++--
 src/server/rdb_save.h | 6 +-
 src/server/replica.cc | 9 +
 src/server/server_family.cc | 6 +-
 src/server/snapshot.cc | 36 +++-
 src/server/snapshot.h | 17 +-
 src/server/transaction.cc | 2 +-
 tests/dragonfly/replication_test.py | 160 +++++++++++++-----
 14 files changed, 543 insertions(+), 199 deletions(-)

diff --git a/src/server/common.h b/src/server/common.h
index e8447ca5394b..75dd77e4ddb3 100644
--- a/src/server/common.h
+++ b/src/server/common.h
@@ -8,6 +8,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -197,6 +198,82 @@ using AggregateStatus = AggregateValue<facade::OpStatus>;
 static_assert(facade::OpStatus::OK == facade::OpStatus{}, "Default intitialization should be OK value");
+// Re-usable component for signaling cancellation.
+// Simple wrapper around an atomic flag.
+struct Cancellation {
+  void Cancel() {
+    flag_.store(true, std::memory_order_relaxed);
+  }
+
+  bool IsCancelled() const {
+    return flag_.load(std::memory_order_relaxed);
+  }
+
+ private:
+  std::atomic_bool flag_;
+};
+
+// Error wrapper that stores an error_code and an optional string message.
+class GenericError {
+ public:
+  GenericError() = default;
+  GenericError(std::error_code ec) : ec_{ec}, details_{} {
+  }
+  GenericError(std::error_code ec, std::string details) : ec_{ec}, details_{std::move(details)} {
+  }
+
+  std::pair<std::error_code, std::string> Get() const {
+    return {ec_, details_};
+  }
+
+  std::error_code GetError() const {
+    return ec_;
+  }
+
+  const std::string& GetDetails() const {
+    return details_;
+  }
+
+  operator bool() const {
+    return bool(ec_);
+  }
+
+ private:
+  std::error_code ec_;
+  std::string details_;
+};
+
+using AggregateGenericError = AggregateValue<GenericError>;
+
+// Context combines Cancellation and AggregateGenericError in one class.
+// Allows setting an error_handler to run on errors.
+class Context : public Cancellation {
+ public:
+  // The error handler should return false if this error is ignored.
+  using ErrHandler = std::function<bool(const GenericError&)>;
+
+  Context() = default;
+  Context(ErrHandler err_handler) : Cancellation{}, err_handler_{std::move(err_handler)} {
+  }
+
+  template <typename... T> void Error(T...
ts) { + std::lock_guard lk{mu_}; + if (err_) + return; + + GenericError new_err{std::forward(ts)...}; + if (!err_handler_ || err_handler_(new_err)) { + err_ = std::move(new_err); + Cancel(); + } + } + + private: + GenericError err_; + ErrHandler err_handler_; + ::boost::fibers::mutex mu_; +}; + struct ScanOpts { std::string_view pattern; size_t limit = 10; diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc index 6531d27997a8..bbdb0a51f6fd 100644 --- a/src/server/dflycmd.cc +++ b/src/server/dflycmd.cc @@ -98,26 +98,6 @@ void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) { rb->SendError(kSyntaxErr); } -void DflyCmd::OnClose(ConnectionContext* cntx) { - unsigned session_id = cntx->conn_state.repl_session_id; - unsigned flow_id = cntx->conn_state.repl_flow_id; - - if (!session_id) - return; - - if (flow_id == kuint32max) { - DeleteSyncSession(session_id); - } else { - shared_ptr sync_info = GetSyncInfo(session_id); - if (sync_info) { - lock_guard lk(sync_info->mu); - if (sync_info->state != SyncState::CANCELLED) { - UnregisterFlow(&sync_info->flows[flow_id]); - } - } - } -} - void DflyCmd::Journal(CmdArgList args, ConnectionContext* cntx) { DCHECK_GE(args.size(), 3u); ToUpper(&args[2]); @@ -227,12 +207,12 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) { return rb->SendError(facade::kInvalidIntErr); } - auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb); + auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb); if (!sync_id) return; - unique_lock lk(sync_info->mu); - if (sync_info->state != SyncState::PREPARATION) + unique_lock lk(replica_ptr->mu); + if (replica_ptr->state != SyncState::PREPARATION) return rb->SendError(kInvalidState); // Set meta info on connection. @@ -243,7 +223,7 @@ void DflyCmd::Flow(CmdArgList args, ConnectionContext* cntx) { absl::InsecureBitGen gen; string eof_token = GetRandomHex(gen, 40); - sync_info->flows[flow_id] = FlowInfo{cntx->owner(), eof_token}; + replica_ptr->flows[flow_id] = FlowInfo{cntx->owner(), eof_token}; listener_->Migrate(cntx->owner(), shard_set->pool()->at(flow_id)); rb->StartArray(2); @@ -257,12 +237,12 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "Got DFLY SYNC " << sync_id_str; - auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb); + auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb); if (!sync_id) return; - unique_lock lk(sync_info->mu); - if (!CheckReplicaStateOrReply(*sync_info, SyncState::PREPARATION, rb)) + unique_lock lk(replica_ptr->mu); + if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::PREPARATION, rb)) return; // Start full sync. 
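The Cancellation/Context primitives added to common.h above are the backbone of the cancellation flow that the dflycmd.cc changes below rely on: long-running tasks poll IsCancelled() as a cooperative cancellation point, and failures are reported through Error(), which consults the error handler and cancels the context on the first accepted error. The following is a minimal, self-contained sketch of that pattern, not code from the patch: it substitutes std::thread/std::mutex for Dragonfly's fiber primitives and a bare std::error_code for GenericError, and all names in it are illustrative.

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <system_error>
#include <thread>

struct Cancellation {
  void Cancel() { flag_.store(true, std::memory_order_relaxed); }
  bool IsCancelled() const { return flag_.load(std::memory_order_relaxed); }
 private:
  std::atomic_bool flag_{false};
};

class Context : public Cancellation {
 public:
  // Handler returns true if the error should be recorded (and the context cancelled).
  using ErrHandler = std::function<bool(const std::error_code&)>;
  explicit Context(ErrHandler h) : err_handler_(std::move(h)) {}

  void Error(std::error_code ec) {
    std::lock_guard<std::mutex> lk(mu_);
    if (err_)  // Keep only the first reported error.
      return;
    if (!err_handler_ || err_handler_(ec)) {
      err_ = ec;
      Cancel();  // Everything polling IsCancelled() winds down cooperatively.
    }
  }

 private:
  std::error_code err_;
  ErrHandler err_handler_;
  std::mutex mu_;
};

int main() {
  Context cntx([](const std::error_code& ec) {
    std::cerr << "replication error: " << ec.message() << '\n';
    return true;  // Record the error and trigger cancellation.
  });

  std::thread worker([&cntx] {
    for (int chunk = 0; chunk < 1000; ++chunk) {
      if (cntx.IsCancelled())  // Cooperative cancellation point.
        return;
      if (chunk == 3)  // Simulated I/O failure while streaming.
        return cntx.Error(std::make_error_code(std::errc::broken_pipe));
    }
  });
  worker.join();
  std::cout << "cancelled: " << cntx.IsCancelled() << '\n';  // Prints 1.
}

The key design point mirrored here is that error reporting and cancellation are fused: whoever hits the first error flips the shared flag, and every other task only ever has to check IsCancelled().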
@@ -270,8 +250,9 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) { TransactionGuard tg{cntx->transaction}; AggregateStatus status; - auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) { - status = StartFullSyncInThread(&sync_info->flows[index], EngineShard::tlocal()); + auto cb = [this, &status, replica_ptr](unsigned index, auto*) { + status = StartFullSyncInThread(&replica_ptr->flows[index], &replica_ptr->cntx, + EngineShard::tlocal()); }; shard_set->pool()->AwaitFiberOnAll(std::move(cb)); @@ -280,7 +261,7 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) { return rb->SendError(kInvalidState); } - sync_info->state = SyncState::FULL_SYNC; + replica_ptr->state = SyncState::FULL_SYNC; return rb->SendOk(); } @@ -290,20 +271,24 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "Got DFLY STARTSTABLE " << sync_id_str; - auto [sync_id, sync_info] = GetSyncInfoOrReply(sync_id_str, rb); + auto [sync_id, replica_ptr] = GetReplicaInfoOrReply(sync_id_str, rb); if (!sync_id) return; - unique_lock lk(sync_info->mu); - if (!CheckReplicaStateOrReply(*sync_info, SyncState::FULL_SYNC, rb)) + unique_lock lk(replica_ptr->mu); + if (!CheckReplicaStateOrReply(*replica_ptr, SyncState::FULL_SYNC, rb)) return; { TransactionGuard tg{cntx->transaction}; AggregateStatus status; - auto cb = [this, &status, sync_info = sync_info](unsigned index, auto*) { - status = StartStableSyncInThread(&sync_info->flows[index], EngineShard::tlocal()); + auto cb = [this, &status, replica_ptr](unsigned index, auto*) { + EngineShard* shard = EngineShard::tlocal(); + FlowInfo* flow = &replica_ptr->flows[index]; + + StopFullSyncInThread(flow, shard); + status = StartStableSyncInThread(flow, shard); return OpStatus::OK; }; shard_set->pool()->AwaitFiberOnAll(std::move(cb)); @@ -312,7 +297,7 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) { return rb->SendError(kInvalidState); } - sync_info->state = SyncState::STABLE_SYNC; + replica_ptr->state = SyncState::STABLE_SYNC; return rb->SendOk(); } @@ -326,49 +311,64 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) { return rb->SendOk(); } -OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, EngineShard* shard) { - DCHECK(!flow->fb.joinable()); +OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) { + DCHECK(!flow->full_sync_fb.joinable()); SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD; flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false)); + flow->cleanup = [flow]() { + flow->saver->Cancel(); + flow->TryShutdownSocket(); + }; + // Shard can be null for io thread. if (shard != nullptr) { - auto ec = sf_->journal()->OpenInThread(false, string_view()); - CHECK(!ec); - flow->saver->StartSnapshotInShard(true, shard); + CHECK(!sf_->journal()->OpenInThread(false, ""sv)); // can only happen in persistent mode. + flow->saver->StartSnapshotInShard(true, cntx, shard); } - flow->fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow); + flow->full_sync_fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow, cntx); return OpStatus::OK; } -OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, EngineShard* shard) { +void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) { // Shard can be null for io thread. if (shard != nullptr) { flow->saver->StopSnapshotInShard(shard); } // Wait for full sync to finish. 
- if (flow->fb.joinable()) { - flow->fb.join(); + if (flow->full_sync_fb.joinable()) { + flow->full_sync_fb.join(); } - if (shard != nullptr) { - flow->saver.reset(); + // Reset cleanup and saver + flow->cleanup = []() {}; + flow->saver.reset(); +} - // TODO: Add cancellation. - auto cb = sf_->journal()->RegisterOnChange([flow](const journal::Entry& je) { +OpStatus DflyCmd::StartStableSyncInThread(FlowInfo* flow, EngineShard* shard) { + // Register journal listener and cleanup. + uint32_t cb_id = 0; + if (shard != nullptr) { + cb_id = sf_->journal()->RegisterOnChange([flow](const journal::Entry& je) { // TODO: Serialize event. ReqSerializer serializer{flow->conn->socket()}; serializer.SendCommand(absl::StrCat("SET ", je.key, " ", je.pval_ptr->ToString())); }); } + flow->cleanup = [flow, this, cb_id]() { + if (cb_id) + sf_->journal()->Unregister(cb_id); + flow->TryShutdownSocket(); + }; + return OpStatus::OK; } -void DflyCmd::FullSyncFb(FlowInfo* flow) { +void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) { error_code ec; RdbSaver* saver = flow->saver.get(); @@ -380,92 +380,121 @@ void DflyCmd::FullSyncFb(FlowInfo* flow) { } if (ec) { - LOG(ERROR) << ec; - return; + return cntx->Error(ec); } - // TODO: we should be able to stop earlier if requested. - ec = saver->SaveBody(nullptr); - if (ec) { - LOG(ERROR) << ec; - return; + if (ec = saver->SaveBody(cntx, nullptr); ec) { + return cntx->Error(ec); } - VLOG(1) << "Sending full sync EOF"; - ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token)); if (ec) { - LOG(ERROR) << ec; - return; + return cntx->Error(ec); } } uint32_t DflyCmd::CreateSyncSession() { unique_lock lk(mu_); + unsigned sync_id = next_sync_id_++; - auto sync_info = make_shared(); - sync_info->flows.resize(shard_set->size() + 1); + unsigned flow_count = shard_set->size() + 1; + auto err_handler = [this, sync_id](const GenericError& err) { + LOG(INFO) << "Replication error: " << err.GetError().message() << " " << err.GetDetails(); - auto [it, inserted] = sync_infos_.emplace(next_sync_id_, std::move(sync_info)); + // Stop replication in case of error. + // StopReplication needs to run async to prevent blocking + // the error handler. + ::boost::fibers::fiber{&DflyCmd::StopReplication, this, sync_id}.detach(); + + return true; // Cancel context + }; + + auto replica_ptr = make_shared(flow_count, std::move(err_handler)); + auto [it, inserted] = replica_infos_.emplace(sync_id, std::move(replica_ptr)); CHECK(inserted); - return next_sync_id_++; + return sync_id; } -void DflyCmd::UnregisterFlow(FlowInfo* flow) { - // TODO: Cancel saver operations. - flow->conn = nullptr; - flow->saver.reset(); -} +void DflyCmd::OnClose(ConnectionContext* cntx) { + unsigned session_id = cntx->conn_state.repl_session_id; + if (!session_id) + return; -void DflyCmd::DeleteSyncSession(uint32_t sync_id) { - shared_ptr sync_info; + auto replica_ptr = GetReplicaInfo(session_id); + if (!replica_ptr) + return; - // Remove sync_info from map. - // Store by value to keep alive. - { - unique_lock lk(mu_); + // Because CancelReplication holds the per-replica mutex, + // aborting connection will block here until cancellation finishes. + // This allows keeping resources alive during the cleanup phase. 
+ CancelReplication(session_id, replica_ptr); +} - auto it = sync_infos_.find(sync_id); - if (it == sync_infos_.end()) - return; +void DflyCmd::StopReplication(uint32_t sync_id) { + auto replica_ptr = GetReplicaInfo(sync_id); + if (!replica_ptr) + return; - sync_info = it->second; - sync_infos_.erase(it); - } + CancelReplication(sync_id, replica_ptr); +} - // Wait for all operations to finish. - // Set state to CANCELLED so no other operations will run. - { - unique_lock lk(sync_info->mu); - sync_info->state = SyncState::CANCELLED; +void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr replica_ptr) { + lock_guard lk(replica_ptr->mu); + if (replica_ptr->state == SyncState::CANCELLED) { + return; } - // Try to cleanup flows. - for (auto& flow : sync_info->flows) { - if (flow.conn != nullptr) { - VLOG(1) << "Flow connection " << flow.conn->GetName() << " is still alive" - << " on sync_id " << sync_id; + LOG(INFO) << "Cancelling sync session " << sync_id; + + // Update replica_ptr state and cancel context. + replica_ptr->state = SyncState::CANCELLED; + replica_ptr->cntx.Cancel(); + + // Run cleanup for shard threads. + shard_set->AwaitRunningOnShardQueue([replica_ptr](EngineShard* shard) { + FlowInfo* flow = &replica_ptr->flows[shard->shard_id()]; + if (flow->cleanup) { + flow->cleanup(); } - // TODO: Implement cancellation. - if (flow.fb.joinable()) { - VLOG(1) << "Force joining fiber on on sync_id " << sync_id; - flow.fb.join(); + }); + + // Wait for tasks to finish. + shard_set->pool()->AwaitFiberOnAll([replica_ptr](unsigned index, auto*) { + FlowInfo* flow = &replica_ptr->flows[index]; + + // Cleanup hasn't been run for io-thread. + if (EngineShard::tlocal() == nullptr) { + if (flow->cleanup) { + flow->cleanup(); + } + } + + if (flow->full_sync_fb.joinable()) { + flow->full_sync_fb.join(); } + }); + + // Remove ReplicaInfo from global map + { + lock_guard lk(mu_); + replica_infos_.erase(sync_id); } + + LOG(INFO) << "Evicted sync session " << sync_id; } -shared_ptr DflyCmd::GetSyncInfo(uint32_t sync_id) { +shared_ptr DflyCmd::GetReplicaInfo(uint32_t sync_id) { unique_lock lk(mu_); - auto it = sync_infos_.find(sync_id); - if (it != sync_infos_.end()) + auto it = replica_infos_.find(sync_id); + if (it != replica_infos_.end()) return it->second; return {}; } -pair> DflyCmd::GetSyncInfoOrReply(std::string_view id_str, - RedisReplyBuilder* rb) { +pair> DflyCmd::GetReplicaInfoOrReply( + std::string_view id_str, RedisReplyBuilder* rb) { unique_lock lk(mu_); uint32_t sync_id; @@ -474,8 +503,8 @@ pair> DflyCmd::GetSyncInfoOrReply(std::s return {0, nullptr}; } - auto sync_it = sync_infos_.find(sync_id); - if (sync_it == sync_infos_.end()) { + auto sync_it = replica_infos_.find(sync_id); + if (sync_it == replica_infos_.end()) { rb->SendError(kIdNotFound); return {0, nullptr}; } @@ -483,7 +512,7 @@ pair> DflyCmd::GetSyncInfoOrReply(std::s return {sync_id, sync_it->second}; } -bool DflyCmd::CheckReplicaStateOrReply(const SyncInfo& sync_info, SyncState expected, +bool DflyCmd::CheckReplicaStateOrReply(const ReplicaInfo& sync_info, SyncState expected, RedisReplyBuilder* rb) { if (sync_info.state != expected) { rb->SendError(kInvalidState); @@ -506,4 +535,11 @@ void DflyCmd::BreakOnShutdown() { VLOG(1) << "BreakOnShutdown"; } +void DflyCmd::FlowInfo::TryShutdownSocket() { + // Close socket for clean disconnect. 
+  if (conn->socket()->IsOpen()) {
+    conn->socket()->Shutdown(SHUT_RDWR);
+  }
+}
+
 } // namespace dfly
diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index abbfbf5ab4b4..f05a4c9bfae2 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -5,9 +5,11 @@
 #pragma once
 #include
-#include
+#include
 #include
+#include
+#include
 #include "server/conn_context.h"
@@ -29,29 +31,84 @@
 namespace journal {
 class Journal;
 } // namespace journal
+// DflyCmd is responsible for managing replication. A master instance can be connected
+// to many replica instances, and each of them can open multiple connections.
+// This is why it's important to understand replica lifecycle management before making
+// any crucial changes.
+//
+// A ReplicaInfo instance is responsible for managing a replica's state and is accessible by its
+// sync_id. Each per-thread connection is called a Flow and is represented by the FlowInfo
+// instance, accessible by its index.
+//
+// An important aspect is synchronization and efficient locking. Two levels of locking are used:
+// 1. Global locking.
+//    Member mutex `mu_` is used for synchronizing operations connected with internal data
+//    structures.
+// 2. Per-replica locking.
+//    ReplicaInfo contains a separate mutex that is used for replica-only routines. It is held
+//    during state transitions (start full sync, start stable state sync), cancellation and member
+//    access.
+//
+// Upon first connection from the replica, a new ReplicaInfo is created.
+// It transitions through the following phases:
+// 1. Preparation
+//    During this start phase the "flows" are set up - one connection for every master thread. Those
+//    connections are registered by the FLOW command sent on each newly opened connection.
+// 2. Full sync
+//    This phase is initiated by the SYNC command. It makes sure all flows are connected and the
+//    replica is in a valid state.
+// 3. Stable state sync
+//    After the replica has received confirmation that each flow is ready to transition, it sends a
+//    STARTSTABLE command. This transitions the replica into streaming journal changes.
+// 4. Cancellation
+//    This can happen due to an error at any phase or through a normal abort. To release resources
+//    properly we need to run a multi-step cancellation procedure:
+//    1. Transition state
+//       We obtain the ReplicaInfo lock, transition into the cancelled state and cancel the context.
+//    2. Joining tasks
+//       Running tasks will stop on receiving the cancellation flag. Each FlowInfo also has an
+//       optional cleanup handler that is invoked after cancelling. This should allow recovering
+//       from any state. The flow's task is awaited and joined if present.
+//    3. Unlocking the mutex
+//       Now that all tasks have finished and all cleanup handlers have run, we can safely release
+//       the per-replica mutex, so that all OnClose handlers will unblock and internal resources
+//       will be released by Dragonfly. Then the ReplicaInfo is removed from the global map.
+//
+//
 class DflyCmd {
 public:
+  // See header comments for state descriptions.
   enum class SyncState { PREPARATION, FULL_SYNC, STABLE_SYNC, CANCELLED };
+  // Stores information related to a single flow.
   struct FlowInfo {
   FlowInfo() = default;
   FlowInfo(facade::Connection* conn, const std::string& eof_token)
-      : conn(conn), eof_token(eof_token){};
+      : conn{conn}, eof_token{eof_token} {};
+
+  // Shut down the associated socket if it's still open.
+ void TryShutdownSocket(); facade::Connection* conn; - std::string eof_token; - std::unique_ptr saver; + ::boost::fibers::fiber full_sync_fb; // Full sync fiber. + std::unique_ptr saver; // Saver used by the full sync phase. + std::string eof_token; - ::boost::fibers::fiber fb; + std::function cleanup; // Optional cleanup for cancellation. }; - struct SyncInfo { - SyncState state = SyncState::PREPARATION; + // Stores information related to a single replica. + struct ReplicaInfo { + ReplicaInfo(unsigned flow_count, Context::ErrHandler err_handler) + : state{SyncState::PREPARATION}, cntx{std::move(err_handler)}, flows{flow_count} { + } - std::vector flows; + SyncState state; + Context cntx; - ::boost::fibers::mutex mu; // guard operations on replica. + std::vector flows; + ::boost::fibers::mutex mu; // See top of header for locking levels. }; public: @@ -93,39 +150,44 @@ class DflyCmd { void Expire(CmdArgList args, ConnectionContext* cntx); // Start full sync in thread. Start FullSyncFb. Called for each flow. - facade::OpStatus StartFullSyncInThread(FlowInfo* flow, EngineShard* shard); + facade::OpStatus StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard); + + // Stop full sync in thread. Run state switch cleanup. + void StopFullSyncInThread(FlowInfo* flow, EngineShard* shard); // Start stable sync in thread. Called for each flow. facade::OpStatus StartStableSyncInThread(FlowInfo* flow, EngineShard* shard); // Fiber that runs full sync for each flow. - void FullSyncFb(FlowInfo* flow); + void FullSyncFb(FlowInfo* flow, Context* cntx); - // Unregister flow. Must be called when flow disconnects. - void UnregisterFlow(FlowInfo*); + // Main entrypoint for stopping replication. + void StopReplication(uint32_t sync_id); - // Delete sync session. Cleanup flows. - void DeleteSyncSession(uint32_t sync_id); + // Transition into cancelled state, run cleanup. + void CancelReplication(uint32_t sync_id, std::shared_ptr replica_info_ptr); - // Get SyncInfo by sync_id. - std::shared_ptr GetSyncInfo(uint32_t sync_id); + // Get ReplicaInfo by sync_id. + std::shared_ptr GetReplicaInfo(uint32_t sync_id); // Find sync info by id or send error reply. - std::pair> GetSyncInfoOrReply(std::string_view id, - facade::RedisReplyBuilder* rb); + std::pair> GetReplicaInfoOrReply( + std::string_view id, facade::RedisReplyBuilder* rb); - bool CheckReplicaStateOrReply(const SyncInfo& si, SyncState expected, + // Check replica is in expected state and flows are set-up correctly. + bool CheckReplicaStateOrReply(const ReplicaInfo& ri, SyncState expected, facade::RedisReplyBuilder* rb); + private: ServerFamily* sf_; util::ListenerInterface* listener_; TxId journal_txid_ = 0; - absl::btree_map> sync_infos_; uint32_t next_sync_id_ = 1; + absl::btree_map> replica_infos_; - ::boost::fibers::mutex mu_; // guard sync info and journal operations. + ::boost::fibers::mutex mu_; // Guard global operations. See header top for locking levels. }; } // namespace dfly diff --git a/src/server/engine_shard_set.h b/src/server/engine_shard_set.h index 9c82c8f73591..6436cdc6e9a9 100644 --- a/src/server/engine_shard_set.h +++ b/src/server/engine_shard_set.h @@ -225,11 +225,26 @@ class EngineShardSet { RunBriefInParallel(std::forward(func), [](auto i) { return true; }); } - // Runs a brief function on selected shards. Waits for it to complete. + // Runs a brief function on selected shard thread. Waits for it to complete. 
template void RunBriefInParallel(U&& func, P&& pred) const; template void RunBlockingInParallel(U&& func); + // Runs func on all shards via the same shard queue that's been used by transactions framework. + // The functions running inside the shard queue run atomically (sequentially) + // with respect each other on the same shard. + template void AwaitRunningOnShardQueue(U&& func) { + util::fibers_ext::BlockingCounter bc{unsigned(shard_queue_.size())}; + for (size_t i = 0; i < shard_queue_.size(); ++i) { + Add(i, [&func, bc]() mutable { + func(EngineShard::tlocal()); + bc.Dec(); + }); + } + + bc.Wait(); + } + // Used in tests void TEST_EnableHeartBeat(); void TEST_EnableCacheMode(); diff --git a/src/server/journal/journal_slice.cc b/src/server/journal/journal_slice.cc index af11226f405f..ff78a8a5eee6 100644 --- a/src/server/journal/journal_slice.cc +++ b/src/server/journal/journal_slice.cc @@ -118,9 +118,11 @@ error_code JournalSlice::Close() { void JournalSlice::AddLogRecord(const Entry& entry) { DCHECK(ring_buffer_); + iterating_cb_arr_ = true; for (const auto& k_v : change_cb_arr_) { k_v.second(entry); } + iterating_cb_arr_ = false; RingItem item; item.lsn = lsn_; @@ -146,12 +148,12 @@ uint32_t JournalSlice::RegisterOnChange(ChangeCallback cb) { } void JournalSlice::Unregister(uint32_t id) { - for (auto it = change_cb_arr_.begin(); it != change_cb_arr_.end(); ++it) { - if (it->first == id) { - change_cb_arr_.erase(it); - break; - } - } + CHECK(!iterating_cb_arr_); + + auto it = find_if(change_cb_arr_.begin(), change_cb_arr_.end(), + [id](const auto& e) { return e.first == id; }); + CHECK(it != change_cb_arr_.end()); + change_cb_arr_.erase(it); } } // namespace journal diff --git a/src/server/journal/journal_slice.h b/src/server/journal/journal_slice.h index e521337043f5..65dc1f74ee61 100644 --- a/src/server/journal/journal_slice.h +++ b/src/server/journal/journal_slice.h @@ -48,12 +48,13 @@ class JournalSlice { void Unregister(uint32_t); private: - struct RingItem; std::string shard_path_; std::unique_ptr shard_file_; std::optional> ring_buffer_; + + bool iterating_cb_arr_ = false; std::vector> change_cb_arr_; size_t file_offset_ = 0; diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index 890ea13ed8c2..91c6bdbe7fe0 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -739,11 +739,11 @@ class RdbSaver::Impl { // correct closing semantics - channel is closing when K producers marked it as closed. 
Impl(bool align_writes, unsigned producers_len, io::Sink* sink); - void StartSnapshotting(bool stream_journal, EngineShard* shard); + void StartSnapshotting(bool stream_journal, const Cancellation* cll, EngineShard* shard); void StopSnapshotting(EngineShard* shard); - error_code ConsumeChannel(); + error_code ConsumeChannel(const Cancellation* cll); error_code Flush() { if (aligned_buf_) @@ -764,6 +764,8 @@ class RdbSaver::Impl { return &meta_serializer_; } + void Cancel(); + private: unique_ptr& GetSnapshot(EngineShard* shard); @@ -797,7 +799,7 @@ error_code RdbSaver::Impl::SaveAuxFieldStrStr(string_view key, string_view val) return error_code{}; } -error_code RdbSaver::Impl::ConsumeChannel() { +error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) { error_code io_error; uint8_t buf[16]; @@ -812,10 +814,13 @@ error_code RdbSaver::Impl::ConsumeChannel() { auto& channel = channel_; while (channel.Pop(record)) { - if (io_error) + if (io_error || cll->IsCancelled()) continue; do { + if (cll->IsCancelled()) + continue; + if (record.db_index != last_db_index) { unsigned enclen = SerializeLen(record.db_index, buf + 1); string_view str{(char*)buf, enclen + 1}; @@ -855,17 +860,32 @@ error_code RdbSaver::Impl::ConsumeChannel() { return io_error; } -void RdbSaver::Impl::StartSnapshotting(bool stream_journal, EngineShard* shard) { +void RdbSaver::Impl::StartSnapshotting(bool stream_journal, const Cancellation* cll, + EngineShard* shard) { auto& s = GetSnapshot(shard); s.reset(new SliceSnapshot(&shard->db_slice(), &channel_)); - s->Start(stream_journal); + s->Start(stream_journal, cll); } void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) { GetSnapshot(shard)->Stop(); } +void RdbSaver::Impl::Cancel() { + auto* shard = EngineShard::tlocal(); + if (!shard) + return; + + auto& snapshot = GetSnapshot(shard); + if (snapshot) + snapshot->Cancel(); + + dfly::SliceSnapshot::DbRecord rec; + while (channel_.Pop(rec)) { + } +} + void RdbSaver::Impl::FillFreqMap(RdbTypeFreqMap* dest) const { for (auto& ptr : shard_snapshots_) { const RdbTypeFreqMap& src_map = ptr->freq_map(); @@ -905,8 +925,9 @@ RdbSaver::RdbSaver(::io::Sink* sink, SaveMode save_mode, bool align_writes) { RdbSaver::~RdbSaver() { } -void RdbSaver::StartSnapshotInShard(bool stream_journal, EngineShard* shard) { - impl_->StartSnapshotting(stream_journal, shard); +void RdbSaver::StartSnapshotInShard(bool stream_journal, const Cancellation* cll, + EngineShard* shard) { + impl_->StartSnapshotting(stream_journal, cll, shard); } void RdbSaver::StopSnapshotInShard(EngineShard* shard) { @@ -924,14 +945,14 @@ error_code RdbSaver::SaveHeader(const StringVec& lua_scripts) { return error_code{}; } -error_code RdbSaver::SaveBody(RdbTypeFreqMap* freq_map) { +error_code RdbSaver::SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map) { RETURN_ON_ERR(impl_->serializer()->FlushMem()); if (save_mode_ == SaveMode::SUMMARY) { impl_->serializer()->SendFullSyncCut(); } else { VLOG(1) << "SaveBody , snapshots count: " << impl_->Size(); - error_code io_error = impl_->ConsumeChannel(); + error_code io_error = impl_->ConsumeChannel(cll); if (io_error) { LOG(ERROR) << "io error " << io_error; return io_error; @@ -1001,4 +1022,8 @@ error_code RdbSaver::SaveAuxFieldStrInt(string_view key, int64_t val) { return impl_->SaveAuxFieldStrStr(key, string_view(buf, vlen)); } +void RdbSaver::Cancel() { + impl_->Cancel(); +} + } // namespace dfly diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h index 02185eb3beab..b6b04f453dd0 100644 --- 
a/src/server/rdb_save.h +++ b/src/server/rdb_save.h @@ -72,7 +72,7 @@ class RdbSaver { // Initiates the serialization in the shard's thread. // TODO: to implement break functionality to allow stopping early. - void StartSnapshotInShard(bool stream_journal, EngineShard* shard); + void StartSnapshotInShard(bool stream_journal, const Cancellation* cll, EngineShard* shard); // Stops serialization in journal streaming mode in the shard's thread. void StopSnapshotInShard(EngineShard* shard); @@ -83,7 +83,9 @@ class RdbSaver { // Writes the RDB file into sink. Waits for the serialization to finish. // Fills freq_map with the histogram of rdb types. // freq_map can optionally be null. - std::error_code SaveBody(RdbTypeFreqMap* freq_map); + std::error_code SaveBody(const Cancellation* cll, RdbTypeFreqMap* freq_map); + + void Cancel(); SaveMode Mode() const { return save_mode_; diff --git a/src/server/replica.cc b/src/server/replica.cc index 3d894a0a2901..a91f7436046f 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -149,6 +149,15 @@ void Replica::Stop() { LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec; }); } + + // Close sub flows. + auto partition = Partition(num_df_flows_); + shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) { + for (auto id : partition[index]) { + shard_flows_[id]->Stop(); + } + }); + if (sync_fb_.joinable()) sync_fb_.join(); } diff --git a/src/server/server_family.cc b/src/server/server_family.cc index a04edc018930..c9cf6a95a33e 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -195,6 +195,8 @@ class RdbSnapshot { std::unique_ptr io_sink_; std::unique_ptr saver_; RdbTypeFreqMap freq_map_; + + Cancellation cll_{}; }; io::Result LinuxWriteWrapper::WriteSome(const iovec* v, uint32_t len) { @@ -229,7 +231,7 @@ error_code RdbSnapshot::Start(SaveMode save_mode, const std::string& path, } error_code RdbSnapshot::SaveBody() { - return saver_->SaveBody(&freq_map_); + return saver_->SaveBody(&cll_, &freq_map_); } error_code RdbSnapshot::Close() { @@ -241,7 +243,7 @@ error_code RdbSnapshot::Close() { } void RdbSnapshot::StartInShard(EngineShard* shard) { - saver_->StartSnapshotInShard(false, shard); + saver_->StartSnapshotInShard(false, &cll_, shard); started_ = true; } diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index 66047a1bd973..8722c05ac984 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -34,7 +34,7 @@ SliceSnapshot::SliceSnapshot(DbSlice* slice, RecordChannel* dest) : db_slice_(sl SliceSnapshot::~SliceSnapshot() { } -void SliceSnapshot::Start(bool stream_journal) { +void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) { DCHECK(!snapshot_fb_.joinable()); auto on_change = [this](DbIndex db_index, const DbSlice::ChangeReq& req) { @@ -54,9 +54,11 @@ void SliceSnapshot::Start(bool stream_journal) { sfile_.reset(new io::StringFile); rdb_serializer_.reset(new RdbSerializer(sfile_.get())); - snapshot_fb_ = fiber([this, stream_journal] { - SerializeEntriesFb(); - if (!stream_journal) { + snapshot_fb_ = fiber([this, stream_journal, cll] { + SerializeEntriesFb(cll); + if (cll->IsCancelled()) { + Cancel(); + } else if (!stream_journal) { CloseRecordChannel(); } db_slice_->UnregisterOnChange(snapshot_version_); @@ -75,6 +77,14 @@ void SliceSnapshot::Stop() { CloseRecordChannel(); } +void SliceSnapshot::Cancel() { + CloseRecordChannel(); + if (journal_cb_id_) { + db_slice_->shard_owner()->journal()->Unregister(journal_cb_id_); + journal_cb_id_ = 0; + } +} + void 
SliceSnapshot::Join() { // Fiber could have already been joined by Stop. if (snapshot_fb_.joinable()) @@ -82,12 +92,15 @@ void SliceSnapshot::Join() { } // Serializes all the entries with version less than snapshot_version_. -void SliceSnapshot::SerializeEntriesFb() { +void SliceSnapshot::SerializeEntriesFb(const Cancellation* cll) { this_fiber::properties().set_name( absl::StrCat("SliceSnapshot", ProactorBase::GetIndex())); PrimeTable::Cursor cursor; for (DbIndex db_indx = 0; db_indx < db_array_.size(); ++db_indx) { + if (cll->IsCancelled()) + return; + if (!db_array_[db_indx]) continue; @@ -100,6 +113,9 @@ void SliceSnapshot::SerializeEntriesFb() { mu_.unlock(); do { + if (cll->IsCancelled()) + return; + PrimeTable::Cursor next = pt->Traverse(cursor, [this](auto it) { this->SaveCb(move(it)); }); cursor = next; @@ -126,7 +142,8 @@ void SliceSnapshot::SerializeEntriesFb() { mu_.lock(); mu_.unlock(); - CHECK(!rdb_serializer_->SendFullSyncCut()); + for (unsigned i = 10; i > 1; i--) + CHECK(!rdb_serializer_->SendFullSyncCut()); FlushSfile(true); VLOG(1) << "Exit SnapshotSerializer (serialized/side_saved/cbcalls): " << serialized_ << "/" @@ -138,7 +155,12 @@ void SliceSnapshot::CloseRecordChannel() { // Can not think of anything more elegant. mu_.lock(); mu_.unlock(); - dest_->StartClosing(); + + // Make sure we close the channel only once with a CAS check. + bool actual = false; + if (closed_chan_.compare_exchange_strong(actual, true)) { + dest_->StartClosing(); + } } // This function should not block and should not preempt because it's called diff --git a/src/server/snapshot.h b/src/server/snapshot.h index db2df0e0c868..a2d2c15d03b7 100644 --- a/src/server/snapshot.h +++ b/src/server/snapshot.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include "io/file.h" @@ -36,12 +37,14 @@ class SliceSnapshot { SliceSnapshot(DbSlice* slice, RecordChannel* dest); ~SliceSnapshot(); - void Start(bool stream_journal); + void Start(bool stream_journal, const Cancellation* cll); - void Stop(); // only needs to be called in journal streaming mode. + void Stop(); // only needs to be called in journal streaming mode. void Join(); + void Cancel(); + uint64_t snapshot_version() const { return snapshot_version_; } @@ -61,8 +64,8 @@ class SliceSnapshot { private: void CloseRecordChannel(); - void SerializeEntriesFb(); - + void SerializeEntriesFb(const Cancellation* cll); + void SerializeSingleEntry(DbIndex db_index, const PrimeKey& pk, const PrimeValue& pv, RdbSerializer* serializer); @@ -89,15 +92,17 @@ class SliceSnapshot { // version upper bound for entries that should be saved (not included). 
uint64_t snapshot_version_ = 0; DbIndex savecb_current_db_; // used by SaveCb - + size_t channel_bytes_ = 0; size_t serialized_ = 0, skipped_ = 0, side_saved_ = 0, savecb_calls_ = 0; uint64_t rec_id_ = 0; uint32_t num_records_in_blob_ = 0; - + uint32_t journal_cb_id_ = 0; ::boost::fibers::fiber snapshot_fb_; + + std::atomic_bool closed_chan_{false}; }; } // namespace dfly diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 6c52e38e3676..c4283109ab3c 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -843,7 +843,7 @@ bool Transaction::ScheduleUniqueShard(EngineShard* shard) { sd.pq_pos = shard->txq()->Insert(this); DCHECK_EQ(0, sd.local_mask & KEYLOCK_ACQUIRED); - bool lock_acquired = shard->db_slice().Acquire(mode, lock_args); + shard->db_slice().Acquire(mode, lock_args); sd.local_mask |= KEYLOCK_ACQUIRED; DVLOG(1) << "Rescheduling into TxQueue " << DebugId(); diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index f53e7e704243..3eb02b8d3de7 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -2,8 +2,11 @@ import pytest import asyncio import aioredis +import random +from itertools import count, chain, repeat from .utility import * +from . import dfly_args BASE_PORT = 1111 @@ -12,6 +15,10 @@ Test full replication pipeline. Test full sync with streaming changes and stable state streaming. """ +# 1. Number of master threads +# 2. Number of threads for each replica +# 3. Number of keys stored and sent in full sync +# 4. Number of keys overwritten during full sync replication_cases = [ (8, [8], 20000, 5000), (8, [8], 10000, 10000), @@ -80,61 +87,140 @@ async def check_replication(c_replica): """ -Test replica crash during full sync on multiple replicas without altering data during replication. +Test disconnecting replicas during different phases with constantly streaming changes to master. +Three types are tested: +1. Replicas crashing during full sync state +2. Replicas crashing during stable sync state +3. Replicas disconnecting normally with REPLICAOF NO ONE during stable state """ - -# (threads_master, threads_replicas, n entries) -simple_full_sync_multi_crash_cases = [ - (5, [1] * 15, 5000), - (5, [1] * 20, 5000), - (5, [1] * 25, 5000) +# 1. Number of master threads +# 2. Number of threads for each replica that crashes during full sync +# 3. Number of threads for each replica that crashes during stable sync +# 4. Number of threads for each replica that disconnects normally +# 5. 
Number of distinct keys that are constantly streamed +disconnect_cases = [ + # balanced + (8, [4, 4], [4, 4], [4], 10000), + (8, [2] * 6, [2] * 6, [2, 2], 10000), + # full sync heavy + (8, [4] * 6, [], [], 10000), + (8, [2] * 12, [], [], 10000), + # stable state heavy + (8, [], [4] * 6, [], 10000), + (8, [], [2] * 12, [], 10000), + # disconnect only + (8, [], [], [2] * 6, 10000) ] @pytest.mark.asyncio -@pytest.mark.skip(reason="test is currently crashing") -@pytest.mark.parametrize("t_master, t_replicas, n_keys", simple_full_sync_multi_crash_cases) -async def test_simple_full_sync_mutli_crash(df_local_factory, t_master, t_replicas, n_keys): - def data_gen(): return gen_test_data(n_keys) - - master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master) +@pytest.mark.parametrize("t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys", disconnect_cases) +async def test_disconnect(df_local_factory, t_master, t_crash_fs, t_crash_ss, t_disonnect, n_keys): + master = df_local_factory.create(port=BASE_PORT, proactor_threads=t_master,logtostdout="") replicas = [ - df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t) - for i, t in enumerate(t_replicas) + (df_local_factory.create( + port=BASE_PORT+i+1, proactor_threads=t), crash_fs) + for i, (t, crash_fs) in enumerate( + chain( + zip(t_crash_fs, repeat(0)), + zip(t_crash_ss, repeat(1)), + zip(t_disonnect, repeat(2)) + ) + ) ] - # Start master and fill with test data + # Start master master.start() c_master = aioredis.Redis(port=master.port, single_connection_client=True) - await batch_fill_data_async(c_master, data_gen()) - # Start replica tasks in parallel - tasks = [ - asyncio.create_task(run_sfs_crash_replica( - replica, master, data_gen), name="replica-"+str(replica.port)) - for replica in replicas + # Start replicas and create clients + for replica, _ in replicas: + replica.start() + + c_replicas = [ + (replica, aioredis.Redis(port=replica.port), crash_type) + for replica, crash_type in replicas ] - for task in tasks: - assert await task + def replicas_of_type(tfunc): + return [ + args for args in c_replicas + if tfunc(args[2]) + ] + + # Start data fill loop + async def fill_loop(): + local_c = aioredis.Redis( + port=master.port, single_connection_client=True) + for seed in count(1): + await batch_fill_data_async(local_c, gen_test_data(n_keys, seed=seed)) + + fill_task = asyncio.create_task(fill_loop()) + + # Run full sync + async def full_sync(replica, c_replica, crash_type): + c_replica = aioredis.Redis(port=replica.port) + await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) + if crash_type == 0: + await asyncio.sleep(random.random()/100+0.01) + replica.stop(kill=True) + else: + await wait_available_async(c_replica) + + await asyncio.gather(*(full_sync(*args) for args in c_replicas)) + + # Wait for master to stream a bit more + await asyncio.sleep(0.1) + + # Check master survived full sync crashes + assert await c_master.ping() + + # Check phase-2 replicas survived + for _, c_replica, _ in replicas_of_type(lambda t: t > 0): + assert await c_replica.ping() + + # Run stable state crashes + async def stable_sync(replica, c_replica, crash_type): + await asyncio.sleep(random.random() / 100) + replica.stop(kill=True) + + await asyncio.gather(*(stable_sync(*args) for args + in replicas_of_type(lambda t: t == 1))) + + # Check master survived all crashes + assert await c_master.ping() + + # Check phase 3 replica survived + for _, c_replica, _ in replicas_of_type(lambda t: t > 1): + assert await 
c_replica.ping() + + # Stop streaming + fill_task.cancel() - # Check master is ok - await batch_check_data_async(c_master, data_gen()) + # Check master survived all crashes + assert await c_master.ping() - await c_master.connection_pool.disconnect() + # Check phase 3 replicas are up-to-date and there is no gap or lag + def check_gen(): return gen_test_data(n_keys//5, seed=0) + await batch_fill_data_async(c_master, check_gen()) + await asyncio.sleep(0.1) + for _, c_replica, _ in replicas_of_type(lambda t: t > 1): + await batch_check_data_async(c_replica, check_gen()) -async def run_sfs_crash_replica(replica, master, data_gen): - replica.start() - c_replica = aioredis.Redis( - port=replica.port, single_connection_client=None) + # Check disconnects + async def disconnect(replica, c_replica, crash_type): + await asyncio.sleep(random.random() / 100) + await c_replica.execute_command("REPLICAOF NO ONE") - await c_replica.execute_command("REPLICAOF localhost " + str(master.port)) + await asyncio.gather(*(disconnect(*args) for args + in replicas_of_type(lambda t: t == 2))) - # Kill the replica after a short delay - await asyncio.sleep(0.0) - replica.stop(kill=True) + # Check phase 3 replica survived + for _, c_replica, _ in replicas_of_type(lambda t: t == 2): + assert await c_replica.ping() + await batch_check_data_async(c_replica, check_gen()) - await c_replica.connection_pool.disconnect() - return True + # Check master survived all disconnects + assert await c_master.ping()
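A second illustrative sketch, again not taken from the patch: snapshot.cc above closes the record channel through an atomic compare-and-swap ("Make sure we close the channel only once with a CAS check"), because Stop(), Cancel() and the snapshot fiber's normal exit can all race to close it. The standalone example below shows that idiom in isolation; the names (SnapshotLike, StartClosing) are stand-ins, not Dragonfly APIs.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

std::atomic<int> start_closing_calls{0};

// Stand-in for the channel's StartClosing(); closing twice would be a bug.
void StartClosing() { start_closing_calls.fetch_add(1); }

class SnapshotLike {
 public:
  void CloseRecordChannel() {
    // compare_exchange_strong flips closed_chan_ from false to true for exactly
    // one caller; every other racer sees the exchange fail and skips closing.
    bool expected = false;
    if (closed_chan_.compare_exchange_strong(expected, true)) {
      StartClosing();
    }
  }

 private:
  std::atomic_bool closed_chan_{false};
};

int main() {
  SnapshotLike snap;

  // Simulate Stop(), Cancel() and the serialization fiber racing to close the channel.
  std::vector<std::thread> racers;
  for (int i = 0; i < 3; ++i)
    racers.emplace_back([&snap] { snap.CloseRecordChannel(); });
  for (auto& t : racers)
    t.join();

  std::cout << "StartClosing ran " << start_closing_calls.load() << " time(s)\n";  // Prints 1.
}

The CAS guard is what lets the cancellation path call CloseRecordChannel() unconditionally without tracking which shutdown path got there first.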