From e6721d816077dda1dc0fbf963dace5cbfd78e084 Mon Sep 17 00:00:00 2001
From: Vladislav
Date: Tue, 27 Dec 2022 16:01:54 +0300
Subject: [PATCH] feat(server): Improved cancellation (#599)

---
 src/server/common.cc                |  35 +++++-
 src/server/common.h                 |  45 ++++----
 src/server/dflycmd.cc               |  24 +++--
 src/server/dflycmd.h                |   5 +-
 src/server/replica.cc               | 158 ++++++++++++++++++----------
 src/server/replica.h                |  10 +-
 tests/dragonfly/replication_test.py |  17 ++-
 7 files changed, 186 insertions(+), 108 deletions(-)

diff --git a/src/server/common.cc b/src/server/common.cc
index 3d189537dbda..59fb593eeab9 100644
--- a/src/server/common.cc
+++ b/src/server/common.cc
@@ -263,6 +263,10 @@ std::string GenericError::Format() const {
   return absl::StrCat(ec_.message(), ":", details_);
 }
 
+Context::~Context() {
+  JoinErrorHandler();
+}
+
 GenericError Context::GetError() {
   std::lock_guard lk(mu_);
   return err_;
@@ -273,20 +277,45 @@ const Cancellation* Context::GetCancellation() const {
 }
 
 void Context::Cancel() {
-  Error(std::make_error_code(errc::operation_canceled), "Context cancelled");
+  ReportError(std::make_error_code(errc::operation_canceled), "Context cancelled");
 }
 
 void Context::Reset(ErrHandler handler) {
   std::lock_guard lk{mu_};
+  JoinErrorHandler();
   err_ = {};
   err_handler_ = std::move(handler);
   Cancellation::flag_.store(false, std::memory_order_relaxed);
 }
 
-GenericError Context::Switch(ErrHandler handler) {
+GenericError Context::SwitchErrorHandler(ErrHandler handler) {
   std::lock_guard lk{mu_};
-  if (!err_)
+  if (!err_) {
+    // No need to check for the error handler - it can't be running
+    // if no error is set.
     err_handler_ = std::move(handler);
+  }
+  return err_;
+}
+
+void Context::JoinErrorHandler() {
+  if (err_handler_fb_.IsJoinable())
+    err_handler_fb_.Join();
+}
+
+GenericError Context::ReportErrorInternal(GenericError&& err) {
+  std::lock_guard lk{mu_};
+  if (err_)
+    return err_;
+  err_ = std::move(err);
+
+  // This context is either new or was Reset, where the handler was joined
+  CHECK(!err_handler_fb_.IsJoinable());
+
+  if (err_handler_)
+    err_handler_fb_ = util::fibers_ext::Fiber{err_handler_, err_};
+
+  Cancellation::Cancel();
   return err_;
 }
 
diff --git a/src/server/common.h b/src/server/common.h
index 93c922f5f84c..9861bfafd22e 100644
--- a/src/server/common.h
+++ b/src/server/common.h
@@ -15,6 +15,7 @@
 
 #include "facade/facade_types.h"
 #include "facade/op_status.h"
+#include "util/fibers/fiber.h"
 
 namespace dfly {
 
@@ -243,7 +244,8 @@ using AggregateGenericError = AggregateValue<GenericError>;
 // Context is a utility for managing error reporting and cancellation for complex tasks.
 //
 // When submitting an error with `Error`, only the first is stored (as in aggregate values).
-// Then a special error handler is run, if present, and the context is cancelled.
+// Then a special error handler is run, if present, and the context is cancelled. The error handler
+// is run in a separate fiber to free up the caller.
 //
 // Manual cancellation with `Cancel` is simulated by reporting an `errc::operation_canceled` error.
 // This allows running the error handler and representing this scenario as an error.
@@ -255,10 +257,10 @@ class Context : protected Cancellation {
   Context(ErrHandler err_handler) : Cancellation{}, err_{}, err_handler_{std::move(err_handler)} {
   }
 
-  // Cancels the context by submitting an `errc::operation_canceled` error.
-  void Cancel();
-
-  using Cancellation::IsCancelled;
+  ~Context();
+
+  void Cancel();  // Cancels the context by submitting an `errc::operation_canceled` error.
+
+  using Cancellation::IsCancelled;
 
   const Cancellation* GetCancellation() const;
 
   GenericError GetError();
 
@@ -266,27 +268,11 @@ class Context : protected Cancellation {
   // Report an error by submitting arguments for GenericError.
   // If this is the first error that occurred, then the error handler is run
   // and the context is cancelled.
-  //
-  // Note: this function blocks when called from inside an error handler.
-  template <typename... T> GenericError Error(T... ts) {
-    if (!mu_.try_lock())  // TODO: Maybe use two separate locks.
-      return GenericError{std::forward<T>(ts)...};
-
-    std::lock_guard lk{mu_, std::adopt_lock};
-    if (err_)
-      return err_;
-
-    GenericError new_err{std::forward<T>(ts)...};
-    if (err_handler_)
-      err_handler_(new_err);
-
-    err_ = std::move(new_err);
-    Cancellation::Cancel();
-
-    return err_;
+  template <typename... T> GenericError ReportError(T... ts) {
+    return ReportErrorInternal(GenericError{std::forward<T>(ts)...});
   }
 
-  // Reset error and cancellation flag, assign new error handler.
+  // Wait for error handler to stop, reset error and cancellation flag, assign new error handler.
   void Reset(ErrHandler handler);
 
   // Atomically replace the error handler if no error is present, and return the
   // current error otherwise.
   //
   // Beware, never do this manually in two steps. If you check for cancellation,
   // set the error handler and initialize resources, then the new error handler
   // will never run if the context was cancelled between the first two steps.
-  GenericError Switch(ErrHandler handler);
+  GenericError SwitchErrorHandler(ErrHandler handler);
+
+  // If any error handler is running, wait for it to stop.
+  void JoinErrorHandler();
+
+ private:
+  // Report error.
+  GenericError ReportErrorInternal(GenericError&& err);
 
  private:
   GenericError err_;
-  ErrHandler err_handler_;
   ::boost::fibers::mutex mu_;
+
+  ErrHandler err_handler_;
+  ::util::fibers_ext::Fiber err_handler_fb_;
 };
 
 struct ScanOpts {
diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index 24f14e09b938..bceb81b3c320 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -314,7 +314,7 @@ void DflyCmd::Expire(CmdArgList args, ConnectionContext* cntx) {
 }
 
 OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
-  DCHECK(!flow->full_sync_fb.joinable());
+  DCHECK(!flow->full_sync_fb.IsJoinable());
 
   SaveMode save_mode = shard == nullptr ? SaveMode::SUMMARY : SaveMode::SINGLE_SHARD;
   flow->saver.reset(new RdbSaver(flow->conn->socket(), save_mode, false));
@@ -341,8 +341,8 @@ void DflyCmd::StopFullSyncInThread(FlowInfo* flow, EngineShard* shard) {
   }
 
   // Wait for full sync to finish.
-  if (flow->full_sync_fb.joinable()) {
-    flow->full_sync_fb.join();
+  if (flow->full_sync_fb.IsJoinable()) {
+    flow->full_sync_fb.Join();
   }
 
   // Reset cleanup and saver
@@ -382,18 +382,18 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
   }
 
   if (ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 
   if (ec = saver->SaveBody(cntx->GetCancellation(), nullptr); ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 
   ec = flow->conn->socket()->Write(io::Buffer(flow->eof_token));
   if (ec) {
-    cntx->Error(ec);
+    cntx->ReportError(ec);
     return;
   }
 }
@@ -406,9 +406,8 @@ uint32_t DflyCmd::CreateSyncSession() {
 
   auto err_handler = [this, sync_id](const GenericError& err) {
     LOG(INFO) << "Replication error: " << err.Format();
-    // Stop replication in case of error.
-    // StopReplication needs to run async to prevent blocking
-    // the error handler.
+    // Spawn external fiber to allow destructing the context from outside
+    // and return from the handler immediately.
     ::boost::fibers::fiber{&DflyCmd::StopReplication, this, sync_id}.detach();
   };
 
@@ -473,8 +472,8 @@ void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr replic
       }
     }
 
-    if (flow->full_sync_fb.joinable()) {
-      flow->full_sync_fb.join();
+    if (flow->full_sync_fb.IsJoinable()) {
+      flow->full_sync_fb.Join();
     }
   });
 
@@ -484,6 +483,9 @@ void DflyCmd::CancelReplication(uint32_t sync_id, shared_ptr replic
     replica_infos_.erase(sync_id);
   }
 
+  // Wait for error handler to quit.
+  replica_ptr->cntx.JoinErrorHandler();
+
   LOG(INFO) << "Evicted sync session " << sync_id;
 }
 
diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index aafc10424f56..3603f9b60430 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -12,6 +12,7 @@
 #include
 
 #include "server/conn_context.h"
+#include "util/fibers/fiber.h"
 
 namespace facade {
 
 class RedisReplyBuilder;
@@ -91,8 +92,8 @@ class DflyCmd {
 
     facade::Connection* conn;
 
-    ::boost::fibers::fiber full_sync_fb;  // Full sync fiber.
-    std::unique_ptr<RdbSaver> saver;      // Saver used by the full sync phase.
+    util::fibers_ext::Fiber full_sync_fb;  // Full sync fiber.
+    std::unique_ptr<RdbSaver> saver;       // Saver used by the full sync phase.
     std::string eof_token;
 
     std::function<void()> cleanup;  // Optional cleanup for cancellation.
 
diff --git a/src/server/replica.cc b/src/server/replica.cc
index 4533841f5ad9..485bb93d49b7 100644
--- a/src/server/replica.cc
+++ b/src/server/replica.cc
@@ -7,6 +7,7 @@
 extern "C" {
 #include "redis/rdb.h"
 }
 
+#include <absl/cleanup/cleanup.h>
 #include
 #include
 #include
@@ -199,23 +200,14 @@ void Replica::MainReplicationFb() {
 
     // 3. Initiate full sync
     if ((state_mask_ & R_SYNC_OK) == 0) {
-      // Make sure we're in LOADING state.
-      if (service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) != GlobalState::LOADING) {
-        state_mask_ = 0;
-        continue;
-      }
-
       if (HasDflyMaster())
         ec = InitiateDflySync();
       else
         ec = InitiatePSync();
 
-      service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
-
       if (ec) {
         LOG(WARNING) << "Error syncing " << ec << " " << ec.message();
         state_mask_ &= R_ENABLED;  // reset all flags besides R_ENABLED
-        JoinAllFlows();
         continue;
       }
 
@@ -230,10 +222,11 @@ void Replica::MainReplicationFb() {
     else
       ec = ConsumeRedisStream();
 
-    JoinAllFlows();
     state_mask_ &= ~R_SYNC_OK;
   }
 
+  cntx_.JoinErrorHandler();
+
   VLOG(1) << "Main replication fiber finished";
 }
 
@@ -385,6 +378,13 @@ error_code Replica::InitiatePSync() {
   SocketSource ss{sock_.get()};
   io::PrefixSource ps{io_buf.InputBuffer(), &ss};
 
+  // Set LOADING state.
+  // TODO: Flush db on retry.
+  CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
+  absl::Cleanup cleanup = [this]() {
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
+  };
+
   RdbLoader loader(NULL);
   loader.set_source_limit(snapshot_size);
   // TODO: to allow registering callbacks within loader to send '\n' pings back to master.
@@ -428,8 +428,16 @@ error_code Replica::InitiatePSync() {
 
 // Initialize and start sub-replica for each flow.
 error_code Replica::InitiateDflySync() {
-  DCHECK_GT(num_df_flows_, 0u);
+  absl::Cleanup cleanup = [this]() {
+    // We do the following operations regardless of outcome.
+    JoinAllFlows();
+    service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
+  };
 
+  // Initialize MultiShardExecution.
   multi_shard_exe_.reset(new MultiShardExecution());
+
+  // Initialize shard flows.
   shard_flows_.resize(num_df_flows_);
   for (unsigned i = 0; i < num_df_flows_; ++i) {
     shard_flows_[i].reset(new Replica(master_context_, i, &service_, multi_shard_exe_));
   }
 
   // Blocked on until all flows got full sync cut.
   fibers_ext::BlockingCounter sync_block{num_df_flows_};
 
+  // Switch to new error handler that closes flow sockets.
   auto err_handler = [this, sync_block](const auto& ge) mutable {
-    sync_block.Cancel();      // Unblock this function.
-    DefaultErrorHandler(ge);  // Close sockets to unblock flows.
+    // Unblock this function.
+    sync_block.Cancel();
+
+    // Make sure the flows are not in a state transition
+    lock_guard lk{flows_op_mu_};
+
+    // Unblock all sockets.
+    DefaultErrorHandler(ge);
+    for (auto& flow : shard_flows_)
+      flow->CloseSocket();
   };
-  RETURN_ON_ERR(cntx_.Switch(std::move(err_handler)));
+  RETURN_ON_ERR(cntx_.SwitchErrorHandler(std::move(err_handler)));
+
+  // Make sure we're in LOADING state.
+  // TODO: Flush db on retry.
+  CHECK(service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING) == GlobalState::LOADING);
 
   // Start full sync flows.
-  auto partition = Partition(num_df_flows_);
-  shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
-    for (auto id : partition[index]) {
-      auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
-      if (ec)
-        cntx_.Error(ec);
-    }
-  });
+  {
+    auto partition = Partition(num_df_flows_);
+    auto shard_cb = [&](unsigned index, auto*) {
+      for (auto id : partition[index]) {
+        auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block, &cntx_);
+        if (ec)
+          cntx_.ReportError(ec);
+      }
+    };
+
+    // Lock to prevent the error handler from running instantly
+    // while the flows are in a mixed state.
+    lock_guard lk{flows_op_mu_};
+    shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
+  }
+
   RETURN_ON_ERR(cntx_.GetError());
 
   // Send DFLY SYNC.
-  if (auto ec = SendNextPhaseRequest(); ec) {
-    return cntx_.Error(ec);
+  if (auto ec = SendNextPhaseRequest(false); ec) {
+    return cntx_.ReportError(ec);
   }
+
   // Wait for all flows to receive full sync cut.
   // In case of an error, this is unblocked by the error handler.
   LOG(INFO) << "Waiting for all full sync cut confirmations";
   sync_block.Wait();
 
-  LOG(INFO) << "Full sync finished";
+  // Check if we woke up due to cancellation.
+  if (cntx_.IsCancelled())
+    return cntx_.GetError();
+
+  // Send DFLY STARTSTABLE.
+  if (auto ec = SendNextPhaseRequest(true); ec) {
+    return cntx_.ReportError(ec);
+  }
+
+  // Joining flows and resetting state is done by cleanup.
+
+  LOG(INFO) << "Full sync finished ";
 
   return cntx_.GetError();
 }
 
@@ -515,40 +556,48 @@ error_code Replica::ConsumeRedisStream() {
 }
 
 error_code Replica::ConsumeDflyStream() {
-  // Send DFLY STARTSTABLE.
-  if (auto ec = SendNextPhaseRequest(); ec) {
-    return cntx_.Error(ec);
+  // Set new error handler that closes flow sockets.
+  auto err_handler = [this](const auto& ge) {
+    // Make sure the flows are not in a state transition
+    lock_guard lk{flows_op_mu_};
+    DefaultErrorHandler(ge);
+    for (auto& flow : shard_flows_)
+      flow->CloseSocket();
+  };
+  RETURN_ON_ERR(cntx_.SwitchErrorHandler(std::move(err_handler)));
+
+  // Transition flows into stable sync.
+  {
+    auto partition = Partition(num_df_flows_);
+    auto shard_cb = [&](unsigned index, auto*) {
+      const auto& local_ids = partition[index];
+      for (unsigned id : local_ids) {
+        auto ec = shard_flows_[id]->StartStableSyncFlow(&cntx_);
+        if (ec)
+          cntx_.ReportError(ec);
+      }
+    };
+
+    // Lock to prevent error handler from running on mixed state.
+    lock_guard lk{flows_op_mu_};
+    shard_set->pool()->AwaitFiberOnAll(std::move(shard_cb));
   }
 
-  // Wait for all flows to finish full sync.
   JoinAllFlows();
 
-  RETURN_ON_ERR(cntx_.Switch(absl::bind_front(&Replica::DefaultErrorHandler, this)));
-
-  vector<vector<unsigned>> partition = Partition(num_df_flows_);
-  shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
-    const auto& local_ids = partition[index];
-    for (unsigned id : local_ids) {
-      auto ec = shard_flows_[id]->StartStableSyncFlow(&cntx_);
-      if (ec)
-        cntx_.Error(ec);
-    }
-  });
+  // The only option to unblock is to cancel the context.
+  CHECK(cntx_.GetError());
 
   return cntx_.GetError();
 }
 
-void Replica::CloseAllSockets() {
+void Replica::CloseSocket() {
   if (sock_) {
     sock_->proactor()->Await([this] {
       auto ec = sock_->Shutdown(SHUT_RDWR);
       LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec;
     });
   }
-
-  for (auto& flow : shard_flows_) {
-    flow->CloseAllSockets();
-  }
 }
 
 void Replica::JoinAllFlows() {
@@ -560,16 +609,18 @@ void Replica::JoinAllFlows() {
 }
 
 void Replica::DefaultErrorHandler(const GenericError& err) {
-  CloseAllSockets();
+  CloseSocket();
 }
 
-error_code Replica::SendNextPhaseRequest() {
+error_code Replica::SendNextPhaseRequest(bool stable) {
   ReqSerializer serializer{sock_.get()};
 
   // Ask master to start sending replication stream
-  string request = (state_mask_ & R_SYNC_OK) ? "STARTSTABLE" : "SYNC";
-  RETURN_ON_ERR(
-      SendCommand(StrCat("DFLY ", request, " ", master_context_.dfly_session_id), &serializer));
+  string_view kind = (stable) ? "STARTSTABLE"sv : "SYNC"sv;
+  string request = StrCat("DFLY ", kind, " ", master_context_.dfly_session_id);
+
+  LOG(INFO) << "Sending: " << request;
+  RETURN_ON_ERR(SendCommand(request, &serializer));
 
   base::IoBuf io_buf{128};
   unsigned consumed = 0;
@@ -657,7 +708,7 @@ void Replica::FullSyncDflyFb(string eof_token, fibers_ext::BlockingCounter bc, C
 
   // Load incoming rdb stream.
   if (std::error_code ec = loader.Load(&ps); ec) {
-    cntx->Error(ec, "Error loading rdb format");
+    cntx->ReportError(ec, "Error loading rdb format");
     return;
   }
 
@@ -670,7 +721,8 @@ void Replica::FullSyncDflyFb(string eof_token, fibers_ext::BlockingCounter bc, C
       chained_tail.ReadAtLeast(io::MutableBytes{buf.get(), eof_token.size()}, eof_token.size());
 
   if (!res || *res != eof_token.size()) {
-    cntx->Error(std::make_error_code(errc::protocol_error), "Error finding eof token in stream");
+    cntx->ReportError(std::make_error_code(errc::protocol_error),
+                      "Error finding eof token in stream");
     return;
   }
 }
@@ -704,7 +756,7 @@ void Replica::StableSyncDflyFb(Context* cntx) {
   while (!cntx->IsCancelled()) {
     auto res = reader.ReadEntry(&ps);
     if (!res) {
-      cntx->Error(res.error(), "Journal format error");
+      cntx->ReportError(res.error(), "Journal format error");
       return;
     }
     ExecuteEntry(&executor, res.value());
diff --git a/src/server/replica.h b/src/server/replica.h
index 93df10740168..317f18d17b35 100644
--- a/src/server/replica.h
+++ b/src/server/replica.h
@@ -89,10 +89,11 @@ class Replica {
   std::error_code ConsumeRedisStream();  // Redis stable state.
   std::error_code ConsumeDflyStream();   // Dragonfly stable state.
 
-  void CloseAllSockets();  // Close all sockets.
-  void JoinAllFlows();     // Join all flows if possible.
+  void CloseSocket();   // Close replica sockets.
+  void JoinAllFlows();  // Join all flows if possible.
 
-  std::error_code SendNextPhaseRequest();  // Send DFLY SYNC or DFLY STARTSTABLE.
+  // Send DFLY SYNC or DFLY STARTSTABLE if stable is true.
+  std::error_code SendNextPhaseRequest(bool stable);
 
   void DefaultErrorHandler(const GenericError& err);
 
@@ -180,6 +181,9 @@ class Replica {
   ::boost::fibers::fiber sync_fb_;
   std::vector<std::unique_ptr<Replica>> shard_flows_;
 
+  // Guard operations where flows might be in a mixed state (transition/setup)
+  ::boost::fibers::mutex flows_op_mu_;
+
   std::unique_ptr<base::IoBuf> leftover_buf_;
   std::unique_ptr<facade::RedisParser> parser_;
   facade::RespVec resp_args_;
 
diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index d9d47794f0ca..d439b2a3d9c2 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -276,16 +276,8 @@ async def test_disconnect_master(df_local_factory, t_master, t_replicas, n_rando
 
     c_replicas = [aioredis.Redis(port=replica.port) for replica in replicas]
 
-    async def full_sync(c_replica):
-        try:
-            await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
-            await wait_available_async(c_replica)
-        except aioredis.ResponseError as e:
-            # This should mean master crashed during greet phase
-            pass
-
     async def crash_master_fs():
-        await asyncio.sleep(random.random() / 10 + 0.01)
+        await asyncio.sleep(random.random() / 10 + 0.1 * len(replicas))
         master.stop(kill=True)
 
     async def start_master():
@@ -296,8 +288,11 @@ async def start_master():
 
     await start_master()
 
-    # Crash master during full sync
-    await asyncio.gather(*(full_sync(c) for c in c_replicas), crash_master_fs())
+    # Crash master during full sync, but with all passing initial connection phase
+    await asyncio.gather(*(c_replica.execute_command("REPLICAOF localhost " + str(master.port))
+                           for c_replica in c_replicas), crash_master_fs())
+
+    await asyncio.sleep(1 + len(replicas) * 0.5)
 
     for _ in range(n_random_crashes):
         await start_master()
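
For reference, the contract the reworked Context enforces in the hunks above is: ReportError keeps only the first error, launches the error handler off the reporting fiber so the caller is not blocked, and cancels the context; SwitchErrorHandler swaps the handler atomically only while no error is set; JoinErrorHandler must run before the context is reset or destroyed. The code below is a minimal, self-contained sketch of that contract, not part of the patch: it substitutes std::thread, std::mutex and std::string for Dragonfly's util::fibers_ext::Fiber, boost::fibers::mutex and GenericError, and the MiniContext name and main() driver are illustrative assumptions only.

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

// Minimal stand-in for the reworked dfly::Context: stores only the first
// reported error, runs the error handler asynchronously, and flips a
// cancellation flag that long-running tasks poll.
class MiniContext {
 public:
  using ErrHandler = std::function<void(const std::string&)>;

  explicit MiniContext(ErrHandler handler) : err_handler_(std::move(handler)) {}
  ~MiniContext() { JoinErrorHandler(); }  // mirrors Context::~Context()

  bool IsCancelled() const { return cancelled_.load(std::memory_order_relaxed); }

  // First error wins; the handler is started off the caller's thread.
  std::string ReportError(std::string err) {
    std::lock_guard lk(mu_);
    if (!err_.empty())
      return err_;
    err_ = std::move(err);
    if (err_handler_)
      err_handler_thread_ = std::thread(err_handler_, err_);
    cancelled_.store(true, std::memory_order_relaxed);
    return err_;
  }

  // Atomically swap the handler only if no error was reported yet,
  // otherwise return the existing error (like SwitchErrorHandler).
  std::string SwitchErrorHandler(ErrHandler handler) {
    std::lock_guard lk(mu_);
    if (err_.empty())
      err_handler_ = std::move(handler);
    return err_;
  }

  // Wait for a running handler before reuse or destruction.
  void JoinErrorHandler() {
    if (err_handler_thread_.joinable())
      err_handler_thread_.join();
  }

 private:
  std::mutex mu_;
  std::string err_;
  ErrHandler err_handler_;
  std::thread err_handler_thread_;
  std::atomic<bool> cancelled_{false};
};

int main() {
  MiniContext cntx([](const std::string& err) { std::cout << "handler: " << err << "\n"; });
  cntx.ReportError("socket closed");         // spawns the handler, cancels the context
  cntx.ReportError("second error ignored");  // only the first error is kept
  std::cout << "cancelled: " << cntx.IsCancelled() << "\n";
  return 0;  // ~MiniContext joins the handler, as CancelReplication does via JoinErrorHandler
}

Running the sketch prints the handler output exactly once and reports the context as cancelled, which is the behaviour DflyCmd::CancelReplication and Replica::ConsumeDflyStream rely on when they join the handler and treat the stored error as the single source of truth.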