diff --git a/src/facade/conn_context.cc b/src/facade/conn_context.cc index 7528c5e4b82f..5e7399d3fa0b 100644 --- a/src/facade/conn_context.cc +++ b/src/facade/conn_context.cc @@ -9,32 +9,9 @@ #include "facade/dragonfly_connection.h" #include "facade/reply_builder.h" -ABSL_FLAG(bool, experimental_new_io, true, - "Use new replying code - should " - "reduce latencies for pipelining"); - namespace facade { -ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : owner_(owner) { - if (owner) { - protocol_ = owner->protocol(); - } - - if (stream) { - switch (protocol_) { - case Protocol::REDIS: { - RedisReplyBuilder* rb = absl::GetFlag(FLAGS_experimental_new_io) - ? new RedisReplyBuilder2(stream) - : new RedisReplyBuilder(stream); - rbuilder_.reset(rb); - break; - } - case Protocol::MEMCACHE: - rbuilder_.reset(new MCReplyBuilder(stream)); - break; - } - } - +ConnectionContext::ConnectionContext(Connection* owner) : owner_(owner) { conn_closing = false; req_auth = false; replica_conn = false; @@ -49,7 +26,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : ow } size_t ConnectionContext::UsedMemory() const { - return dfly::HeapSize(rbuilder_) + dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands); + return dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands); } } // namespace facade diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index 70f82e403e13..b87e901b473d 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -19,7 +19,7 @@ class Connection; class ConnectionContext { public: - ConnectionContext(::io::Sink* stream, Connection* owner); + explicit ConnectionContext(Connection* owner); virtual ~ConnectionContext() { } @@ -32,14 +32,6 @@ class ConnectionContext { return owner_; } - Protocol protocol() const { - return protocol_; - } - - SinkReplyBuilder* reply_builder_old() { - return rbuilder_.get(); - } - virtual size_t UsedMemory() const; // connection state / properties. @@ -71,8 +63,6 @@ class ConnectionContext { private: Connection* owner_; - Protocol protocol_ = Protocol::REDIS; - std::unique_ptr rbuilder_; }; } // namespace facade diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 995e4b6e01b2..d2bd4935d8b0 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -86,6 +86,9 @@ ABSL_FLAG(bool, migrate_connections, true, "they operate. Currently this is only supported for Lua script invocations, and can " "happen at most once per connection."); +ABSL_FLAG(bool, experimental_new_io, true, + "Use new replying code - should reduce latencies for pipelining"); + using namespace util; using absl::GetFlag; using nonstd::make_unexpected; @@ -487,14 +490,15 @@ void Connection::DispatchOperations::operator()(Connection::PipelineMessage& msg DVLOG(2) << "Dispatching pipeline: " << ToSV(msg.args.front()); self->service_->DispatchCommand(CmdArgList{msg.args.data(), msg.args.size()}, - self->reply_builder_, self->cc_.get()); + self->reply_builder_.get(), self->cc_.get()); self->last_interaction_ = time(nullptr); self->skip_next_squashing_ = false; } void Connection::DispatchOperations::operator()(const Connection::MCPipelineMessage& msg) { - self->service_->DispatchMC(msg.cmd, msg.value, static_cast(self->reply_builder_), + self->service_->DispatchMC(msg.cmd, msg.value, + static_cast(self->reply_builder_.get()), self->cc_.get()); self->last_interaction_ = time(nullptr); } @@ -535,25 +539,24 @@ void UpdateLibNameVerMap(const string& name, const string& ver, int delta) { } } // namespace -Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, +Connection::Connection(Protocol proto, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service) : io_buf_(kMinReadSize), + protocol_(proto), http_listener_(http_listener), ssl_ctx_(ctx), service_(service), flags_(0) { static atomic_uint32_t next_id{1}; - protocol_ = protocol; - constexpr size_t kReqSz = sizeof(Connection::PipelineMessage); static_assert(kReqSz <= 256 && kReqSz >= 200); - switch (protocol) { - case Protocol::REDIS: + switch (proto) { + case REDIS: redis_parser_.reset(new RedisParser(GetFlag(FLAGS_max_multi_bulk_len))); break; - case Protocol::MEMCACHE: + case MEMCACHE: memcache_parser_.reset(new MemcacheParser); break; } @@ -724,8 +727,7 @@ void Connection::HandleRequests() { // because both Write and Recv internally check if the socket was shut // down and return with an error accordingly. if (http_res && socket_->IsOpen()) { - cc_.reset(service_->CreateContext(socket_.get(), this)); - reply_builder_ = cc_->reply_builder_old(); + cc_.reset(service_->CreateContext(this)); if (*http_res) { VLOG(1) << "HTTP1.1 identified"; @@ -749,7 +751,18 @@ void Connection::HandleRequests() { if (breaker_cb_) { socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); }); } - + switch (protocol_) { + case REDIS: { + RedisReplyBuilder* rb = absl::GetFlag(FLAGS_experimental_new_io) + ? new RedisReplyBuilder2(socket_.get()) + : new RedisReplyBuilder(socket_.get()); + reply_builder_.reset(rb); + break; + } + case MEMCACHE: + reply_builder_.reset(new MCReplyBuilder(socket_.get())); + break; + } ConnectionFlow(); socket_->CancelOnErrorCb(); // noop if nothing is registered. @@ -757,7 +770,7 @@ void Connection::HandleRequests() { VLOG(1) << "Closed connection for peer " << GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex()); cc_.reset(); - reply_builder_ = nullptr; + reply_builder_.reset(); } } @@ -929,6 +942,8 @@ io::Result Connection::CheckForHttpProto() { } void Connection::ConnectionFlow() { + DCHECK(reply_builder_); + ++stats_->num_conns; ++stats_->conn_received_cnt; stats_->read_buf_capacity += io_buf_.Capacity(); @@ -986,7 +1001,7 @@ void Connection::ConnectionFlow() { VLOG(1) << "Error parser status " << parser_error_; if (redis_parser_) { - SendProtocolError(RedisParser::Result(parser_error_), reply_builder_); + SendProtocolError(RedisParser::Result(parser_error_), reply_builder_.get()); } else { DCHECK(memcache_parser_); reply_builder_->SendProtocolError("bad command line format"); @@ -1089,7 +1104,7 @@ Connection::ParserStatus Connection::ParseRedis() { auto dispatch_sync = [this, &parse_args, &cmd_vec] { RespExpr::VecToArgList(parse_args, &cmd_vec); - service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_, cc_.get()); + service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_.get(), cc_.get()); }; auto dispatch_async = [this, &parse_args, tlh = mi_heap_get_backing()]() -> MessageHandle { return {FromArgs(std::move(parse_args), tlh)}; @@ -1134,14 +1149,14 @@ auto Connection::ParseMemcache() -> ParserStatus { string_view value; auto dispatch_sync = [this, &cmd, &value] { - service_->DispatchMC(cmd, value, static_cast(reply_builder_), cc_.get()); + service_->DispatchMC(cmd, value, static_cast(reply_builder_.get()), cc_.get()); }; auto dispatch_async = [&cmd, &value]() -> MessageHandle { return {make_unique(std::move(cmd), value)}; }; - MCReplyBuilder* builder = static_cast(reply_builder_); + MCReplyBuilder* builder = static_cast(reply_builder_.get()); do { string_view str = ToSV(io_buf_.InputBuffer()); @@ -1358,7 +1373,7 @@ bool Connection::ShouldEndDispatchFiber(const MessageHandle& msg) { void Connection::SquashPipeline() { DCHECK_EQ(dispatch_q_.size(), pending_pipeline_cmd_cnt_); - DCHECK_EQ(reply_builder_->type(), SinkReplyBuilder::REDIS); // Only Redis is supported. + DCHECK_EQ(reply_builder_->type(), REDIS); // Only Redis is supported. vector squash_cmds; squash_cmds.reserve(dispatch_q_.size()); @@ -1374,7 +1389,7 @@ void Connection::SquashPipeline() { cc_->async_dispatch = true; size_t dispatched = - service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_, cc_.get()); + service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_.get(), cc_.get()); if (pending_pipeline_cmd_cnt_ == squash_cmds.size()) { // Flush if no new commands appeared reply_builder_->FlushBatch(); @@ -1397,7 +1412,7 @@ void Connection::SquashPipeline() { } void Connection::ClearPipelinedMessages() { - DispatchOperations dispatch_op{reply_builder_, this}; + DispatchOperations dispatch_op{reply_builder_.get(), this}; // Recycle messages even from disconnecting client to keep properly track of memory stats // As well as to avoid pubsub backpressure leakege. @@ -1445,7 +1460,7 @@ std::string Connection::DebugInfo() const { void Connection::ExecutionFiber() { ThisFiber::SetName("ExecutionFiber"); - DispatchOperations dispatch_op{reply_builder_, this}; + DispatchOperations dispatch_op{reply_builder_.get(), this}; size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash); @@ -1809,7 +1824,7 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) + dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) + dfly::HeapSize(memcache_parser_) + dfly::HeapSize(redis_parser_) + - dfly::HeapSize(cc_); + dfly::HeapSize(cc_) + dfly::HeapSize(reply_builder_); // We add a hardcoded 9k value to accomodate for the part of the Fiber stack that is in use. // The allocated stack is actually larger (~130k), but only a small fraction of that (9k diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index bc1af66c30b2..e0a4841603b0 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -57,7 +57,7 @@ class Connection : public util::Connection { struct QueueBackpressure; public: - Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, + Connection(Protocol type, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service); ~Connection(); @@ -269,10 +269,6 @@ class Connection : public util::Connection { bool IsMain() const; - Protocol protocol() const { - return protocol_; - } - void SetName(std::string name); void SetLibName(std::string name); @@ -404,9 +400,7 @@ class Connection : public util::Connection { Protocol protocol_; ConnectionStats* stats_ = nullptr; - // cc_->reply_builder may change during the lifetime of the connection, due to injections. - // This is a pointer to the original, socket based reply builder that never changes. - SinkReplyBuilder* reply_builder_ = nullptr; + std::unique_ptr reply_builder_; util::HttpListenerBase* http_listener_; SSL_CTX* ssl_ctx_; diff --git a/src/facade/dragonfly_listener.h b/src/facade/dragonfly_listener.h index 7ad8bc0e39e8..106884424e81 100644 --- a/src/facade/dragonfly_listener.h +++ b/src/facade/dragonfly_listener.h @@ -50,6 +50,10 @@ class Listener : public util::ListenerInterface { bool IsPrivilegedInterface() const; bool IsMainInterface() const; + Protocol protocol() const { + return protocol_; + } + private: util::Connection* NewConnection(ProactorBase* proactor) final; ProactorBase* PickConnectionProactor(util::FiberSocketBase* sock) final; diff --git a/src/facade/facade.cc b/src/facade/facade.cc index 52a6d7cea5e9..01612cf2a747 100644 --- a/src/facade/facade.cc +++ b/src/facade/facade.cc @@ -209,4 +209,17 @@ ostream& operator<<(ostream& os, facade::RespSpan ras) { return os; } +ostream& operator<<(ostream& os, facade::Protocol p) { + switch (p) { + case facade::REDIS: + os << "REDIS"; + break; + case facade::MEMCACHE: + os << "MEMCACHE"; + break; + } + + return os; +} + } // namespace std diff --git a/src/facade/facade_types.h b/src/facade/facade_types.h index 1dbcc9087161..be3808437f4d 100644 --- a/src/facade/facade_types.h +++ b/src/facade/facade_types.h @@ -33,7 +33,7 @@ constexpr size_t kSanitizerOverhead = 0u; #endif #endif -enum class Protocol : uint8_t { MEMCACHE = 1, REDIS = 2 }; +enum Protocol : uint8_t { MEMCACHE = 1, REDIS = 2 }; using MutableSlice = std::string_view; using CmdArgList = absl::Span; @@ -189,5 +189,5 @@ void ResetStats(); namespace std { ostream& operator<<(ostream& os, facade::CmdArgList args); - +ostream& operator<<(ostream& os, facade::Protocol proto); } // namespace std diff --git a/src/facade/ok_main.cc b/src/facade/ok_main.cc index 0b8b06823e2c..9b74896e4ca0 100644 --- a/src/facade/ok_main.cc +++ b/src/facade/ok_main.cc @@ -37,8 +37,8 @@ class OkService : public ServiceInterface { builder->SendError(""); } - ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) final { - return new ConnectionContext{peer, owner}; + ConnectionContext* CreateContext(Connection* owner) final { + return new ConnectionContext{owner}; } }; @@ -46,7 +46,7 @@ void RunEngine(ProactorPool* pool, AcceptServer* acceptor) { OkService service; pool->Await([](auto*) { tl_facade_stats = new FacadeStats; }); - acceptor->AddListener(GetFlag(FLAGS_port), new Listener{Protocol::REDIS, &service}); + acceptor->AddListener(GetFlag(FLAGS_port), new Listener{REDIS, &service}); acceptor->Run(); acceptor->Wait(); diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 478e3d855fa6..cf6078d85bea 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -338,7 +338,8 @@ void SinkReplyBuilder2::NextVec(std::string_view str) { vecs_.push_back(iovec{const_cast(str.data()), str.size()}); } -MCReplyBuilder::MCReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink, MC), noreply_(false) { +MCReplyBuilder::MCReplyBuilder(::io::Sink* sink) + : SinkReplyBuilder(sink, MEMCACHE), noreply_(false) { } void MCReplyBuilder::SendSimpleString(std::string_view str) { diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 9cd120702ff1..177d405c874e 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -67,7 +67,7 @@ class SinkReplyBuilder { SinkReplyBuilder(const SinkReplyBuilder&) = delete; void operator=(const SinkReplyBuilder&) = delete; - enum Type { REDIS, MC }; + using Type = Protocol; explicit SinkReplyBuilder(::io::Sink* sink, Type t); diff --git a/src/facade/service_interface.h b/src/facade/service_interface.h index e881cba2c33d..aefc5d38c198 100644 --- a/src/facade/service_interface.h +++ b/src/facade/service_interface.h @@ -36,7 +36,7 @@ class ServiceInterface { virtual void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, MCReplyBuilder* builder, ConnectionContext* cntx) = 0; - virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0; + virtual ConnectionContext* CreateContext(Connection* owner) = 0; virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) { } diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 93a5796c83f6..82bbc6c203c1 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -27,6 +27,7 @@ #include "base/logging.h" #include "core/overloaded.h" #include "facade/dragonfly_connection.h" +#include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "io/file.h" #include "io/file_util.h" @@ -102,14 +103,13 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) { DCHECK(conn); auto connection = static_cast(conn); - if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() && - connection->cntx()) { + if (!connection->IsHttp() && connection->cntx()) { connection->SendAclUpdateAsync( facade::Connection::AclUpdateMessage{user, update_commands, update_keys, update_pub_sub}); } }; - if (main_listener_) { + if (main_listener_ && main_listener_->protocol() == facade::REDIS) { main_listener_->TraverseConnections(update_cb); } } diff --git a/src/server/acl/acl_family.h b/src/server/acl/acl_family.h index fc09faf30d64..97f925f71760 100644 --- a/src/server/acl/acl_family.h +++ b/src/server/acl/acl_family.h @@ -10,7 +10,6 @@ #include #include "absl/container/flat_hash_set.h" -#include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "helio/util/proactor_pool.h" #include "server/acl/acl_commands_def.h" @@ -20,6 +19,7 @@ namespace facade { class SinkReplyBuilder; +class Listener; } // namespace facade namespace dfly { diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 6cdd66f1a12b..7c1ed55f8bd9 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -93,9 +93,8 @@ const CommandId* StoredCmd::Cid() const { return cid_; } -ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* owner, - acl::UserCredentials cred) - : facade::ConnectionContext(stream, owner) { +ConnectionContext::ConnectionContext(facade::Connection* owner, acl::UserCredentials cred) + : facade::ConnectionContext(owner) { if (owner) { skip_acl_validation = owner->IsPrivileged(); } @@ -110,7 +109,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own } ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx) - : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { + : facade::ConnectionContext(nullptr), transaction{tx} { if (owner) { acl_commands = owner->acl_commands; keys = owner->keys; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index dc8c50442179..731212546136 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -266,7 +266,7 @@ struct ConnectionState { class ConnectionContext : public facade::ConnectionContext { public: - ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred); + ConnectionContext(facade::Connection* owner, dfly::acl::UserCredentials cred); ConnectionContext(const ConnectionContext* owner, Transaction* tx); struct DebugInfo { diff --git a/src/server/dfly_main.cc b/src/server/dfly_main.cc index 5caa7cdc1bf9..25f95e0e58e4 100644 --- a/src/server/dfly_main.cc +++ b/src/server/dfly_main.cc @@ -184,7 +184,7 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) { // we depend on tcp listener to be at the front since we later // need to pass it to the AclFamily::Init if (!tcp_disabled) { - auto listener = MakeListener(Protocol::REDIS, &service, Listener::Role::MAIN); + auto listener = MakeListener(REDIS, &service, Listener::Role::MAIN); main_listener = listener.get(); listeners.push_back(listener.release()); } @@ -232,7 +232,7 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) { } unlink(unix_sock.c_str()); - auto uds_listener = MakeListener(Protocol::REDIS, &service); + auto uds_listener = MakeListener(REDIS, &service); error_code ec = acceptor->AddUDSListener(unix_sock.c_str(), unix_socket_perm, uds_listener.get()); if (ec) { @@ -262,7 +262,7 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) { const char* interface_addr = admin_bind.empty() ? nullptr : admin_bind.c_str(); const std::string printable_addr = absl::StrCat("admin socket ", interface_addr ? interface_addr : "any", ":", admin_port); - auto admin_listener = MakeListener(Protocol::REDIS, &service, Listener::Role::PRIVILEGED); + auto admin_listener = MakeListener(REDIS, &service, Listener::Role::PRIVILEGED); error_code ec = acceptor->AddListener(interface_addr, admin_port, admin_listener.get()); @@ -288,7 +288,7 @@ bool RunEngine(ProactorPool* pool, AcceptServer* acceptor) { } if (mc_port > 0 && !tcp_disabled) { - auto listener = MakeListener(Protocol::MEMCACHE, &service); + auto listener = MakeListener(MEMCACHE, &service); acceptor->AddListener(mc_port, listener.get()); listeners.push_back(listener.release()); } diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index 390d70793a7d..2cec9fb9c57c 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -978,7 +978,7 @@ void GenericFamily::Del(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil VLOG(1) << "Del " << ArgS(args, 0); atomic_uint32_t result{0}; - bool is_mc = (builder->type() == SinkReplyBuilder::MC); + bool is_mc = (builder->type() == MEMCACHE); auto cb = [&result](const Transaction* t, EngineShard* shard) { ShardArgs args = t->GetShardArgs(shard->shard_id()); diff --git a/src/server/main_service.cc b/src/server/main_service.cc index dd091d12b5ba..b68457e2c105 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1350,7 +1350,7 @@ class ReplyGuard { const bool is_script = bool(cntx->conn_state.script_info); const bool is_one_of = absl::flat_hash_set({"REPLCONF", "DFLY"}).contains(cid_name); - bool is_mcache = builder->type() == SinkReplyBuilder::MC; + bool is_mcache = builder->type() == MEMCACHE; const bool is_no_reply_memcache = is_mcache && (static_cast(builder)->NoReply() || cid_name == "QUIT"); const bool should_dcheck = !is_one_of && !is_script && !is_no_reply_memcache; @@ -1495,7 +1495,7 @@ size_t Service::DispatchManyCommands(absl::Span args_list, SinkReply facade::ConnectionContext* cntx) { ConnectionContext* dfly_cntx = static_cast(cntx); DCHECK(!dfly_cntx->conn_state.exec_info.IsRunning()); - DCHECK_EQ(builder->type(), SinkReplyBuilder::REDIS); + DCHECK_EQ(builder->type(), REDIS); vector stored_cmds; intrusive_ptr dist_trans; @@ -1691,13 +1691,12 @@ bool RequirePrivilegedAuth() { return !GetFlag(FLAGS_admin_nopass); } -facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer, - facade::Connection* owner) { +facade::ConnectionContext* Service::CreateContext(facade::Connection* owner) { auto cred = user_registry_.GetCredentials("default"); - ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)}; + ConnectionContext* res = new ConnectionContext{owner, std::move(cred)}; res->ns = &namespaces.GetOrInsert(""); - if (peer->IsUDS()) { + if (owner->socket()->IsUDS()) { res->req_auth = false; res->skip_acl_validation = true; } else if (owner->IsPrivileged() && RequirePrivilegedAuth()) { @@ -1746,9 +1745,9 @@ absl::flat_hash_map Service::UknownCmdMap() const { void Service::Quit(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx) { - if (builder->type() == SinkReplyBuilder::REDIS) + if (builder->type() == REDIS) builder->SendOk(); - using facade::SinkReplyBuilder; + builder->CloseConnection(); DeactivateMonitoring(static_cast(cntx)); diff --git a/src/server/main_service.h b/src/server/main_service.h index 710bdf77383c..471184ffdf88 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -62,8 +62,7 @@ class Service : public facade::ServiceInterface { void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, facade::MCReplyBuilder* builder, facade::ConnectionContext* cntx) final; - facade::ConnectionContext* CreateContext(util::FiberSocketBase* peer, - facade::Connection* owner) final; + facade::ConnectionContext* CreateContext(facade::Connection* owner) final; const CommandId* FindCmd(std::string_view) const; diff --git a/src/server/replica.cc b/src/server/replica.cc index 58fb53bb7d4e..0457c471c403 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -593,7 +593,7 @@ error_code Replica::InitiateDflySync() { error_code Replica::ConsumeRedisStream() { base::IoBuf io_buf(16_KB); - ConnectionContext conn_context{static_cast(nullptr), nullptr, {}}; + ConnectionContext conn_context{nullptr, {}}; conn_context.is_replicating = true; conn_context.journal_emulated = true; conn_context.skip_acl_validation = true; diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 6c5b21d2596c..5779f188e45c 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -650,7 +650,7 @@ void ExtendGeneric(CmdArgList args, bool prepend, Transaction* tx, SinkReplyBuil string_view value = ArgS(args, 1); VLOG(2) << "ExtendGeneric(" << key << ", " << value << ")"; - if (builder->type() == SinkReplyBuilder::REDIS) { + if (builder->type() == REDIS) { auto cb = [&](Transaction* t, EngineShard* shard) { return OpExtend(t->GetOpArgs(shard), key, value, prepend); }; @@ -662,7 +662,7 @@ void ExtendGeneric(CmdArgList args, bool prepend, Transaction* tx, SinkReplyBuil rb->SendLong(GetResult(std::move(res.value()))); } else { // Memcached skips if key is missing - DCHECK(builder->type() == SinkReplyBuilder::MC); + DCHECK(builder->type() == MEMCACHE); auto cb = [&](Transaction* t, EngineShard* shard) { return ExtendOrSkip(t->GetOpArgs(shard), key, value, prepend); @@ -723,7 +723,7 @@ void SetExGeneric(bool seconds, CmdArgList args, const CommandId* cid, Transacti } void IncrByGeneric(string_view key, int64_t val, Transaction* tx, SinkReplyBuilder* builder) { - bool skip_on_missing = builder->type() == SinkReplyBuilder::MC; + bool skip_on_missing = builder->type() == MEMCACHE; auto cb = [&](Transaction* t, EngineShard* shard) { OpResult res = OpIncrBy(t->GetOpArgs(shard), key, val, skip_on_missing); @@ -1256,7 +1256,7 @@ void StringFamily::MGet(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil std::vector mget_resp(shard_set->size()); uint8_t fetch_mask = 0; - if (builder->type() == SinkReplyBuilder::MC) { + if (builder->type() == MEMCACHE) { fetch_mask |= FETCH_MCFLAG; if (cntx->conn_state.memcache_flag & ConnectionState::FETCH_CAS_VER) fetch_mask |= FETCH_MCVER; @@ -1296,7 +1296,7 @@ void StringFamily::MGet(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil uint32_t indx = it.index(); res.resp_arr[indx] = std::move(src.resp_arr[src_indx]); - if (builder->type() == SinkReplyBuilder::MC) { + if (builder->type() == MEMCACHE) { res.resp_arr[indx]->key = *it; } } diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index d86d8404f60f..51b3935bbdf4 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -59,7 +59,7 @@ static vector SplitLines(const std::string& src) { TestConnection::TestConnection(Protocol protocol, io::StringSink* sink) : facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) { - cc_.reset(new dfly::ConnectionContext(sink_, this, {})); + cc_.reset(new dfly::ConnectionContext(this, {})); cc_->skip_acl_validation = true; SetSocket(ProactorBase::me()->CreateSocket()); OnConnectionStart(); @@ -125,6 +125,10 @@ class BaseFamilyTest::TestConnWrapper { return dummy_conn_.get(); } + SinkReplyBuilder* builder() { + return builder_.get(); + } + private: ::io::StringSink sink_; // holds the response blob @@ -133,10 +137,21 @@ class BaseFamilyTest::TestConnWrapper { std::vector> tmp_str_vec_; std::unique_ptr parser_; + std::unique_ptr builder_; }; BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto) : dummy_conn_(new TestConnection(proto, &sink_)) { + switch (proto) { + case REDIS: + builder_.reset(new RedisReplyBuilder{&sink_}); + break; + case MEMCACHE: + builder_.reset(new MCReplyBuilder{&sink_}); + break; + default: + LOG(FATAL) << "Unknown protocol"; + } } BaseFamilyTest::TestConnWrapper::~TestConnWrapper() { @@ -357,7 +372,7 @@ RespExpr BaseFamilyTest::RunPrivileged(std::initializer_listat(0)->Await([&] { return this->RunPrivileged(list); }); } string id = GetId(); - TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id); + TestConnWrapper* conn_wrapper = AddFindConn(facade::REDIS, id); // Before running the command set the connection as admin connection conn_wrapper->conn()->SetPrivileged(true); auto res = Run(id, ArgSlice{list.begin(), list.size()}); @@ -381,7 +396,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { return pp_->at(0)->Await([&] { return this->Run(id, slice); }); } - TestConnWrapper* conn_wrapper = AddFindConn(Protocol::REDIS, id); + TestConnWrapper* conn_wrapper = AddFindConn(facade::REDIS, id); CmdArgVec args = conn_wrapper->Args(slice); @@ -390,7 +405,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { DCHECK(context->transaction == nullptr) << id; - service_->DispatchCommand(CmdArgList{args}, context->reply_builder_old(), context); + service_->DispatchCommand(CmdArgList{args}, conn_wrapper->builder(), context); DCHECK(context->transaction == nullptr); @@ -427,14 +442,13 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va cmd.bytes_len = value.size(); cmd.expire_ts = ttl.count(); - TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); + TestConnWrapper* conn = AddFindConn(facade::MEMCACHE, GetId()); auto* context = conn->cmd_cntx(); DCHECK(context->transaction == nullptr); - service_->DispatchMC(cmd, value, static_cast(context->reply_builder_old()), - context); + service_->DispatchMC(cmd, value, static_cast(conn->builder()), context); DCHECK(context->transaction == nullptr); @@ -446,17 +460,7 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp return pp_->at(0)->Await([&] { return this->RunMC(cmd_type, key); }); } - MP::Command cmd; - cmd.type = cmd_type; - cmd.key = key; - TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); - - auto* context = conn->cmd_cntx(); - - service_->DispatchMC(cmd, string_view{}, - static_cast(context->reply_builder_old()), context); - - return conn->SplitLines(); + return RunMC(cmd_type, key, string_view{}, 0, chrono::seconds{}); } auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list list) @@ -476,12 +480,10 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_listcmd_cntx(); - - service_->DispatchMC(cmd, string_view{}, - static_cast(context->reply_builder_old()), context); + service_->DispatchMC(cmd, string_view{}, static_cast(conn->builder()), context); return conn->SplitLines(); }