diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 9c22416366bf..bd1f3863da3f 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -491,10 +491,26 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorAwait( [this, &flag]() { this->Replicate(flag.host, flag.port); }); - return; // DONT load any snapshots + } else { // load from snapshot only if --replicaof is empty + LoadFromSnapshot(); } - const auto load_path_result = snapshot_storage_->LoadPath(flag_dir, GetFlag(FLAGS_dbfilename)); + const auto create_snapshot_schedule_fb = [this] { + snapshot_schedule_fb_ = + service_.proactor_pool().GetNextProactor()->LaunchFiber([this] { SnapshotScheduling(); }); + }; + config_registry.RegisterMutable( + "snapshot_cron", [this, create_snapshot_schedule_fb](const absl::CommandLineFlag& flag) { + JoinSnapshotSchedule(); + create_snapshot_schedule_fb(); + return true; + }); + create_snapshot_schedule_fb(); +} + +void ServerFamily::LoadFromSnapshot() { + const auto load_path_result = + snapshot_storage_->LoadPath(GetFlag(FLAGS_dir), GetFlag(FLAGS_dbfilename)); if (load_path_result) { const std::string load_path = *load_path_result; if (!load_path.empty()) { @@ -507,19 +523,6 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorLaunchFiber([this] { SnapshotScheduling(); }); - }; - config_registry.RegisterMutable( - "snapshot_cron", [this, create_snapshot_schedule_fb](const absl::CommandLineFlag& flag) { - JoinSnapshotSchedule(); - create_snapshot_schedule_fb(); - return true; - }); - - create_snapshot_schedule_fb(); } void ServerFamily::JoinSnapshotSchedule() { @@ -1937,9 +1940,14 @@ void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) { void ServerFamily::ReplicaOfInternal(string_view host, string_view port_sv, ConnectionContext* cntx, ActionOnConnectionFail on_err) { LOG(INFO) << "Replicating " << host << ":" << port_sv; - unique_lock lk(replicaof_mu_); // Only one REPLICAOF command can run at a time + // We should not execute replica of command while loading from snapshot. + if (ServerState::tlocal()->is_master && service_.GetGlobalState() == GlobalState::LOADING) { + cntx->SendError("Can not execute during LOADING"); + return; + } + // If NO ONE was supplied, just stop the current replica (if it exists) if (IsReplicatingNoOne(host, port_sv)) { if (!ServerState::tlocal()->is_master) { diff --git a/src/server/server_family.h b/src/server/server_family.h index 6ef954a13419..7222f3b73a19 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -204,6 +204,7 @@ class ServerFamily { private: void JoinSnapshotSchedule(); + void LoadFromSnapshot(); uint32_t shard_count() const { return shard_set->size(); diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index eedbc5c545ea..bc874f276e6e 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -1806,3 +1806,33 @@ async def test_client_pause_with_replica(df_local_factory, df_seeder_factory): assert await seeder.compare(capture, port=replica.port) await disconnect_clients(c_master, c_replica) + + +async def test_replicaof_reject_on_load(df_local_factory, df_seeder_factory): + tmp_file_name = "".join(random.choices(string.ascii_letters, k=10)) + master = df_local_factory.create() + replica = df_local_factory.create(dbfilename=f"dump_{tmp_file_name}") + df_local_factory.start_all([master, replica]) + + seeder = df_seeder_factory.create(port=replica.port, keys=30000) + await seeder.run(target_deviation=0.1) + c_replica = replica.client() + dbsize = await c_replica.dbsize() + assert dbsize >= 9000 + + replica.stop() + replica.start() + c_replica = replica.client() + # Check replica of not alowed while loading snapshot + try: + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + assert False + except aioredis.ResponseError as e: + assert "Can not execute during LOADING" in str(e) + # Check one we finish loading snapshot replicaof success + await wait_available_async(c_replica) + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + + await c_replica.close() + master.stop() + replica.stop()