diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index c530f51d58a3..a6070dcd93c7 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -37,6 +37,8 @@ class ConnectionContext { } // A convenient proxy for redis interface. + // Use with caution -- should only be used only + // in execution paths that are Redis *only* RedisReplyBuilder* operator->(); SinkReplyBuilder* reply_builder() { @@ -50,6 +52,18 @@ class ConnectionContext { return res; } + void SendError(std::string_view str, std::string_view type = std::string_view{}) { + rbuilder_->SendError(str, type); + } + + void SendError(ErrorReply&& error) { + rbuilder_->SendError(std::move(error)); + } + + void SendSimpleString(std::string_view str) { + rbuilder_->SendSimpleString(str); + } + // connection state / properties. bool conn_closing : 1; bool req_auth : 1; diff --git a/src/facade/op_status.cc b/src/facade/op_status.cc index 778e4a53be81..6441a2ba797d 100644 --- a/src/facade/op_status.cc +++ b/src/facade/op_status.cc @@ -1,3 +1,37 @@ #include "facade/op_status.h" -namespace facade {} // namespace facade +#include "base/logging.h" +#include "facade/error.h" +#include "facade/resp_expr.h" + +namespace facade { + +std::string_view StatusToMsg(OpStatus status) { + switch (status) { + case OpStatus::OK: + return "OK"; + case OpStatus::KEY_NOTFOUND: + return kKeyNotFoundErr; + case OpStatus::WRONG_TYPE: + return kWrongTypeErr; + case OpStatus::OUT_OF_RANGE: + return kIndexOutOfRange; + case OpStatus::INVALID_FLOAT: + return kInvalidFloatErr; + case OpStatus::INVALID_INT: + return kInvalidIntErr; + case OpStatus::SYNTAX_ERR: + return kSyntaxErr; + case OpStatus::OUT_OF_MEMORY: + return kOutOfMemory; + case OpStatus::BUSY_GROUP: + return "-BUSYGROUP Consumer Group name already exists"; + case OpStatus::INVALID_NUMERIC_RESULT: + return kInvalidNumericResult; + default: + LOG(ERROR) << "Unsupported status " << status; + return "Internal error"; + } +} + +} // namespace facade diff --git a/src/facade/op_status.h b/src/facade/op_status.h index ef9e54814b6a..0f175afc208c 100644 --- a/src/facade/op_status.h +++ b/src/facade/op_status.h @@ -124,6 +124,8 @@ inline bool operator==(OpStatus st, const OpResultBase& ob) { return ob.operator==(st); } +std::string_view StatusToMsg(OpStatus status); + } // namespace facade namespace std { diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index ff85248782f5..5ca3955f911f 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -94,6 +94,22 @@ void SinkReplyBuilder::SendRaw(std::string_view raw) { Send(&v, 1); } +void SinkReplyBuilder::SendError(ErrorReply error) { + if (error.status) + return SendError(*error.status); + + string_view message_sv = visit([](auto&& str) -> string_view { return str; }, error.message); + SendError(message_sv, error.kind); +} + +void SinkReplyBuilder::SendError(OpStatus status) { + if (status == OpStatus::OK) { + SendOk(); + } else { + SendError(StatusToMsg(status)); + } +} + void SinkReplyBuilder::SendRawVec(absl::Span msg_vec) { absl::FixedArray arr(msg_vec.size()); @@ -223,14 +239,6 @@ void RedisReplyBuilder::SendError(string_view str, string_view err_type) { } } -void RedisReplyBuilder::SendError(ErrorReply error) { - if (error.status) - return SendError(*error.status); - - string_view message_sv = visit([](auto&& str) -> string_view { return str; }, error.message); - SendError(message_sv, error.kind); -} - void RedisReplyBuilder::SendProtocolError(std::string_view str) { SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error"); } @@ -277,42 +285,6 @@ void RedisReplyBuilder::SendBulkString(std::string_view str) { return Send(v, ABSL_ARRAYSIZE(v)); } -std::string_view RedisReplyBuilder::StatusToMsg(OpStatus status) { - switch (status) { - case OpStatus::OK: - return "OK"; - case OpStatus::KEY_NOTFOUND: - return kKeyNotFoundErr; - case OpStatus::WRONG_TYPE: - return kWrongTypeErr; - case OpStatus::OUT_OF_RANGE: - return kIndexOutOfRange; - case OpStatus::INVALID_FLOAT: - return kInvalidFloatErr; - case OpStatus::INVALID_INT: - return kInvalidIntErr; - case OpStatus::SYNTAX_ERR: - return kSyntaxErr; - case OpStatus::OUT_OF_MEMORY: - return kOutOfMemory; - case OpStatus::BUSY_GROUP: - return "-BUSYGROUP Consumer Group name already exists"; - case OpStatus::INVALID_NUMERIC_RESULT: - return kInvalidNumericResult; - default: - LOG(ERROR) << "Unsupported status " << status; - return "Internal error"; - } -} - -void RedisReplyBuilder::SendError(OpStatus status) { - if (status == OpStatus::OK) { - SendOk(); - } else { - SendError(StatusToMsg(status)); - } -} - void RedisReplyBuilder::SendLong(long num) { string str = absl::StrCat(":", num, kCRLF); SendRaw(str); diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index dc15172f214a..3634789e109b 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -42,6 +42,8 @@ class SinkReplyBuilder { } virtual void SendError(std::string_view str, std::string_view type = {}) = 0; // MC and Redis + virtual void SendError(ErrorReply error); + virtual void SendError(OpStatus status); virtual void SendStored() = 0; // Reply for set commands. virtual void SendSetSkipped() = 0; @@ -177,13 +179,12 @@ class RedisReplyBuilder : public SinkReplyBuilder { void SetResp3(bool is_resp3); void SendError(std::string_view str, std::string_view type = {}) override; - virtual void SendError(ErrorReply error); + using SinkReplyBuilder::SendError; void SendMGetResponse(absl::Span) override; void SendStored() override; void SendSetSkipped() override; - virtual void SendError(OpStatus status); void SendProtocolError(std::string_view str) override; virtual void SendNullArray(); // Send *-1 @@ -206,10 +207,6 @@ class RedisReplyBuilder : public SinkReplyBuilder { static char* FormatDouble(double val, char* dest, unsigned dest_len); - // You normally should not call this - maps the status - // into the string that would be sent - static std::string_view StatusToMsg(OpStatus status); - protected: struct WrappedStrSpan : public StrSpan { size_t Size() const; diff --git a/src/facade/reply_builder_test.cc b/src/facade/reply_builder_test.cc index 8b2b297be84d..95aa045d3326 100644 --- a/src/facade/reply_builder_test.cc +++ b/src/facade/reply_builder_test.cc @@ -232,7 +232,7 @@ TEST_F(RedisReplyBuilderTest, ErrorBuiltInMessage) { OpStatus::OUT_OF_MEMORY, OpStatus::INVALID_FLOAT, OpStatus::INVALID_INT, OpStatus::SYNTAX_ERR, OpStatus::BUSY_GROUP, OpStatus::INVALID_NUMERIC_RESULT}; for (const auto& err : error_codes) { - const std::string_view error_name = RedisReplyBuilder::StatusToMsg(err); + const std::string_view error_name = StatusToMsg(err); const std::string_view error_type = GetErrorType(error_name); sink_.Clear(); @@ -251,6 +251,31 @@ TEST_F(RedisReplyBuilderTest, ErrorBuiltInMessage) { } } +TEST_F(RedisReplyBuilderTest, ErrorReplyBuiltInMessage) { + ErrorReply err{OpStatus::OUT_OF_RANGE}; + builder_->SendError(err); + ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); + ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); + ASSERT_EQ(builder_->err_count().at(kIndexOutOfRange), 1); + ASSERT_EQ(str(), BuildExpectedErrorString(kIndexOutOfRange)); + + auto parsing_output = Parse(); + ASSERT_TRUE(parsing_output.Verify(SinkSize())); + ASSERT_TRUE(parsing_output.IsError()); + sink_.Clear(); + + err = ErrorReply{"e1", "e2"}; + builder_->SendError(err); + ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); + ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); + ASSERT_EQ(builder_->err_count().at("e2"), 1); + ASSERT_EQ(str(), BuildExpectedErrorString("e1")); + + parsing_output = Parse(); + ASSERT_TRUE(parsing_output.Verify(SinkSize())); + ASSERT_TRUE(parsing_output.IsError()); +} + TEST_F(RedisReplyBuilderTest, ErrorNoneBuiltInMessage) { // All these op codes creating the same error message OpStatus none_unique_codes[] = {OpStatus::ENTRIES_ADDED_SMALL, OpStatus::SKIPPED, @@ -258,7 +283,7 @@ TEST_F(RedisReplyBuilderTest, ErrorNoneBuiltInMessage) { OpStatus::TIMED_OUT, OpStatus::STREAM_ID_SMALL}; uint64_t error_count = 0; for (const auto& err : none_unique_codes) { - const std::string_view error_name = RedisReplyBuilder::StatusToMsg(err); + const std::string_view error_name = StatusToMsg(err); const std::string_view error_type = GetErrorType(error_name); sink_.Clear(); diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 1b83f4af21fb..bfd61e0126d5 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -879,7 +879,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) const auto [cid, args_no_cmd] = FindCmd(args); if (cid == nullptr) { - return (*cntx)->SendError(ReportUnknownCmd(ArgS(args, 0))); + return cntx->SendError(ReportUnknownCmd(ArgS(args, 0))); } ConnectionContext* dfly_cntx = static_cast(cntx); @@ -899,7 +899,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) if (auto& exec_info = dfly_cntx->conn_state.exec_info; exec_info.IsCollecting()) exec_info.state = ConnectionState::ExecInfo::EXEC_ERROR; - (*dfly_cntx)->SendError(std::move(*err)); + dfly_cntx->SendError(std::move(*err)); return; } @@ -909,13 +909,13 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) StoredCmd stored_cmd{cid, args_no_cmd}; dfly_cntx->conn_state.exec_info.body.push_back(std::move(stored_cmd)); - return (*cntx)->SendSimpleString("QUEUED"); + return cntx->SendSimpleString("QUEUED"); } uint64_t start_ns = absl::GetCurrentTimeNanos(); if (cid->opt_mask() & CO::DENYOOM) { - int64_t used_memory = etl.GetUsedMemory(start_ns); + uint64_t used_memory = etl.GetUsedMemory(start_ns); double oom_deny_ratio = GetFlag(FLAGS_oom_deny_ratio); if (used_memory > (max_memory_limit * oom_deny_ratio)) { return cntx->reply_builder()->SendError(kOutOfMemory); diff --git a/tests/dragonfly/__init__.py b/tests/dragonfly/__init__.py index 5f4559009f9a..9d1c47c3e88b 100644 --- a/tests/dragonfly/__init__.py +++ b/tests/dragonfly/__init__.py @@ -106,7 +106,7 @@ def admin_port(self) -> int: def mc_port(self) -> int: if self.params.existing_mc_port: return self.params.existing_mc_port - return int(self.args.get("mc_port", "11211")) + return int(self.args.get("memcached_port", "11211")) @staticmethod def format_args(args): diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py index eda47b5cb27d..5ce59f229c49 100644 --- a/tests/dragonfly/replication_test.py +++ b/tests/dragonfly/replication_test.py @@ -7,6 +7,7 @@ from redis import asyncio as aioredis from .utility import * from . import DflyInstanceFactory, dfly_args +import pymemcache import logging BASE_PORT = 1111 @@ -1506,3 +1507,37 @@ async def test_replicaof_flag_disconnect(df_local_factory): role = await c_replica.role() assert role[0] == b"master" + + +@pytest.mark.asyncio +async def test_df_crash_on_memcached_error(df_local_factory): + master = df_local_factory.create( + port=BASE_PORT, + memcached_port=11211, + proactor_threads=2, + ) + + replica = df_local_factory.create( + port=master.port + 1, + memcached_port=master.mc_port + 1, + proactor_threads=2, + ) + + master.start() + replica.start() + + c_master = aioredis.Redis(port=master.port) + await wait_available_async(c_master) + + c_replica = aioredis.Redis(port=replica.port) + await c_replica.execute_command(f"REPLICAOF localhost {master.port}") + await wait_available_async(c_replica) + await wait_for_replica_status(c_replica, status="up") + await c_replica.close() + + memcached_client = pymemcache.Client(f"localhost:{replica.mc_port}") + + with pytest.raises(pymemcache.exceptions.MemcacheClientError): + memcached_client.set(b"key", b"data", noreply=False) + + await c_master.close()