diff --git a/src/facade/conn_context.cc b/src/facade/conn_context.cc index 4753def33f29..9879383765fe 100644 --- a/src/facade/conn_context.cc +++ b/src/facade/conn_context.cc @@ -12,26 +12,7 @@ namespace facade { -ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : owner_(owner) { - if (owner) { - protocol_ = owner->protocol(); - } - - if (stream) { - switch (protocol_) { - case Protocol::NONE: - LOG(DFATAL) << "Invalid protocol"; - break; - case Protocol::REDIS: { - rbuilder_.reset(new RedisReplyBuilder(stream)); - break; - } - case Protocol::MEMCACHE: - rbuilder_.reset(new MCReplyBuilder(stream)); - break; - } - } - +ConnectionContext::ConnectionContext(Connection* owner) : owner_(owner) { conn_closing = false; req_auth = false; replica_conn = false; @@ -46,7 +27,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : ow } size_t ConnectionContext::UsedMemory() const { - return dfly::HeapSize(rbuilder_) + dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands); + return dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands); } } // namespace facade diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index 70f82e403e13..b87e901b473d 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -19,7 +19,7 @@ class Connection; class ConnectionContext { public: - ConnectionContext(::io::Sink* stream, Connection* owner); + explicit ConnectionContext(Connection* owner); virtual ~ConnectionContext() { } @@ -32,14 +32,6 @@ class ConnectionContext { return owner_; } - Protocol protocol() const { - return protocol_; - } - - SinkReplyBuilder* reply_builder_old() { - return rbuilder_.get(); - } - virtual size_t UsedMemory() const; // connection state / properties. @@ -71,8 +63,6 @@ class ConnectionContext { private: Connection* owner_; - Protocol protocol_ = Protocol::REDIS; - std::unique_ptr rbuilder_; }; } // namespace facade diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index e1508700efc1..eac75dcf3c38 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -487,14 +487,15 @@ void Connection::DispatchOperations::operator()(Connection::PipelineMessage& msg DVLOG(2) << "Dispatching pipeline: " << ToSV(msg.args.front()); self->service_->DispatchCommand(CmdArgList{msg.args.data(), msg.args.size()}, - self->reply_builder_, self->cc_.get()); + self->reply_builder_.get(), self->cc_.get()); self->last_interaction_ = time(nullptr); self->skip_next_squashing_ = false; } void Connection::DispatchOperations::operator()(const Connection::MCPipelineMessage& msg) { - self->service_->DispatchMC(msg.cmd, msg.value, static_cast(self->reply_builder_), + self->service_->DispatchMC(msg.cmd, msg.value, + static_cast(self->reply_builder_.get()), self->cc_.get()); self->last_interaction_ = time(nullptr); } @@ -538,14 +539,13 @@ void UpdateLibNameVerMap(const string& name, const string& ver, int delta) { Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service) : io_buf_(kMinReadSize), + protocol_(protocol), http_listener_(http_listener), ssl_ctx_(ctx), service_(service), flags_(0) { static atomic_uint32_t next_id{1}; - protocol_ = protocol; - constexpr size_t kReqSz = sizeof(Connection::PipelineMessage); static_assert(kReqSz <= 256 && kReqSz >= 200); @@ -727,8 +727,7 @@ void Connection::HandleRequests() { // because both Write and Recv internally check if the socket was shut // down and return with an error accordingly. if (http_res && socket_->IsOpen()) { - cc_.reset(service_->CreateContext(socket_.get(), this)); - reply_builder_ = cc_->reply_builder_old(); + cc_.reset(service_->CreateContext(this)); if (*http_res) { VLOG(1) << "HTTP1.1 identified"; @@ -748,19 +747,28 @@ void Connection::HandleRequests() { // Release the ownership of the socket from http_conn so it would stay with // this connection. http_conn.ReleaseSocket(); - } else { + } else { // non-http if (breaker_cb_) { socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); }); } - + switch (protocol_) { + case Protocol::REDIS: + reply_builder_.reset(new RedisReplyBuilder(socket_.get())); + break; + case Protocol::MEMCACHE: + reply_builder_.reset(new MCReplyBuilder(socket_.get())); + break; + default: + break; + } ConnectionFlow(); socket_->CancelOnErrorCb(); // noop if nothing is registered. + VLOG(1) << "Closed connection for peer " + << GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex()); + reply_builder_.reset(); } - VLOG(1) << "Closed connection for peer " - << GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex()); cc_.reset(); - reply_builder_ = nullptr; } } @@ -932,6 +940,8 @@ io::Result Connection::CheckForHttpProto() { } void Connection::ConnectionFlow() { + DCHECK(reply_builder_); + ++stats_->num_conns; ++stats_->conn_received_cnt; stats_->read_buf_capacity += io_buf_.Capacity(); @@ -989,7 +999,7 @@ void Connection::ConnectionFlow() { VLOG(1) << "Error parser status " << parser_error_; if (redis_parser_) { - SendProtocolError(RedisParser::Result(parser_error_), reply_builder_); + SendProtocolError(RedisParser::Result(parser_error_), reply_builder_.get()); } else { DCHECK(memcache_parser_); reply_builder_->SendProtocolError("bad command line format"); @@ -1092,7 +1102,7 @@ Connection::ParserStatus Connection::ParseRedis() { auto dispatch_sync = [this, &parse_args, &cmd_vec] { RespExpr::VecToArgList(parse_args, &cmd_vec); - service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_, cc_.get()); + service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_.get(), cc_.get()); }; auto dispatch_async = [this, &parse_args, tlh = mi_heap_get_backing()]() -> MessageHandle { return {FromArgs(std::move(parse_args), tlh)}; @@ -1137,14 +1147,14 @@ auto Connection::ParseMemcache() -> ParserStatus { string_view value; auto dispatch_sync = [this, &cmd, &value] { - service_->DispatchMC(cmd, value, static_cast(reply_builder_), cc_.get()); + service_->DispatchMC(cmd, value, static_cast(reply_builder_.get()), cc_.get()); }; auto dispatch_async = [&cmd, &value]() -> MessageHandle { return {make_unique(std::move(cmd), value)}; }; - MCReplyBuilder* builder = static_cast(reply_builder_); + MCReplyBuilder* builder = static_cast(reply_builder_.get()); do { string_view str = ToSV(io_buf_.InputBuffer()); @@ -1377,7 +1387,7 @@ void Connection::SquashPipeline() { cc_->async_dispatch = true; size_t dispatched = - service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_, cc_.get()); + service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_.get(), cc_.get()); if (pending_pipeline_cmd_cnt_ == squash_cmds.size()) { // Flush if no new commands appeared reply_builder_->Flush(); @@ -1400,7 +1410,7 @@ void Connection::SquashPipeline() { } void Connection::ClearPipelinedMessages() { - DispatchOperations dispatch_op{reply_builder_, this}; + DispatchOperations dispatch_op{reply_builder_.get(), this}; // Recycle messages even from disconnecting client to keep properly track of memory stats // As well as to avoid pubsub backpressure leakege. @@ -1448,7 +1458,7 @@ std::string Connection::DebugInfo() const { void Connection::ExecutionFiber() { ThisFiber::SetName("ExecutionFiber"); - DispatchOperations dispatch_op{reply_builder_, this}; + DispatchOperations dispatch_op{reply_builder_.get(), this}; size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash); @@ -1812,7 +1822,7 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) + dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) + dfly::HeapSize(memcache_parser_) + dfly::HeapSize(redis_parser_) + - dfly::HeapSize(cc_); + dfly::HeapSize(cc_) + dfly::HeapSize(reply_builder_); // We add a hardcoded 9k value to accomodate for the part of the Fiber stack that is in use. // The allocated stack is actually larger (~130k), but only a small fraction of that (9k diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index bc1af66c30b2..401bbfe14ac8 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -269,10 +269,6 @@ class Connection : public util::Connection { bool IsMain() const; - Protocol protocol() const { - return protocol_; - } - void SetName(std::string name); void SetLibName(std::string name); @@ -404,9 +400,7 @@ class Connection : public util::Connection { Protocol protocol_; ConnectionStats* stats_ = nullptr; - // cc_->reply_builder may change during the lifetime of the connection, due to injections. - // This is a pointer to the original, socket based reply builder that never changes. - SinkReplyBuilder* reply_builder_ = nullptr; + std::unique_ptr reply_builder_; util::HttpListenerBase* http_listener_; SSL_CTX* ssl_ctx_; diff --git a/src/facade/dragonfly_listener.h b/src/facade/dragonfly_listener.h index 7ad8bc0e39e8..106884424e81 100644 --- a/src/facade/dragonfly_listener.h +++ b/src/facade/dragonfly_listener.h @@ -50,6 +50,10 @@ class Listener : public util::ListenerInterface { bool IsPrivilegedInterface() const; bool IsMainInterface() const; + Protocol protocol() const { + return protocol_; + } + private: util::Connection* NewConnection(ProactorBase* proactor) final; ProactorBase* PickConnectionProactor(util::FiberSocketBase* sock) final; diff --git a/src/facade/facade.cc b/src/facade/facade.cc index cce96aa75866..9261bf3c6e75 100644 --- a/src/facade/facade.cc +++ b/src/facade/facade.cc @@ -165,10 +165,6 @@ ostream& operator<<(ostream& os, facade::CmdArgList ras) { return os; } -ostream& operator<<(ostream& os, facade::Protocol p) { - return os << int(p); -} - ostream& operator<<(ostream& os, const facade::RespExpr& e) { using facade::RespExpr; using facade::ToSV; @@ -213,4 +209,19 @@ ostream& operator<<(ostream& os, facade::RespSpan ras) { return os; } +ostream& operator<<(ostream& os, facade::Protocol p) { + switch (p) { + case facade::Protocol::REDIS: + os << "REDIS"; + break; + case facade::Protocol::MEMCACHE: + os << "MEMCACHE"; + break; + default: + os << "NONE"; + } + + return os; +} + } // namespace std diff --git a/src/facade/ok_main.cc b/src/facade/ok_main.cc index 0b8b06823e2c..7ff9e4d62b37 100644 --- a/src/facade/ok_main.cc +++ b/src/facade/ok_main.cc @@ -37,8 +37,8 @@ class OkService : public ServiceInterface { builder->SendError(""); } - ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) final { - return new ConnectionContext{peer, owner}; + ConnectionContext* CreateContext(Connection* owner) final { + return new ConnectionContext{owner}; } }; diff --git a/src/facade/service_interface.h b/src/facade/service_interface.h index e881cba2c33d..aefc5d38c198 100644 --- a/src/facade/service_interface.h +++ b/src/facade/service_interface.h @@ -36,7 +36,7 @@ class ServiceInterface { virtual void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, MCReplyBuilder* builder, ConnectionContext* cntx) = 0; - virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0; + virtual ConnectionContext* CreateContext(Connection* owner) = 0; virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) { } diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 93a5796c83f6..44250cb24251 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -27,6 +27,7 @@ #include "base/logging.h" #include "core/overloaded.h" #include "facade/dragonfly_connection.h" +#include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "io/file.h" #include "io/file_util.h" @@ -102,14 +103,13 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) { DCHECK(conn); auto connection = static_cast(conn); - if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() && - connection->cntx()) { + if (!connection->IsHttp() && connection->cntx()) { connection->SendAclUpdateAsync( facade::Connection::AclUpdateMessage{user, update_commands, update_keys, update_pub_sub}); } }; - if (main_listener_) { + if (main_listener_ && main_listener_->protocol() == facade::Protocol::REDIS) { main_listener_->TraverseConnections(update_cb); } } diff --git a/src/server/acl/acl_family.h b/src/server/acl/acl_family.h index fc09faf30d64..97f925f71760 100644 --- a/src/server/acl/acl_family.h +++ b/src/server/acl/acl_family.h @@ -10,7 +10,6 @@ #include #include "absl/container/flat_hash_set.h" -#include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "helio/util/proactor_pool.h" #include "server/acl/acl_commands_def.h" @@ -20,6 +19,7 @@ namespace facade { class SinkReplyBuilder; +class Listener; } // namespace facade namespace dfly { diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 40652c34e00d..034d0292a7ab 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -100,9 +100,8 @@ const CommandId* StoredCmd::Cid() const { return cid_; } -ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* owner, - acl::UserCredentials cred) - : facade::ConnectionContext(stream, owner) { +ConnectionContext::ConnectionContext(facade::Connection* owner, acl::UserCredentials cred) + : facade::ConnectionContext(owner) { if (owner) { skip_acl_validation = owner->IsPrivileged(); } @@ -117,7 +116,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own } ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx) - : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { + : facade::ConnectionContext(nullptr), transaction{tx} { if (owner) { acl_commands = owner->acl_commands; keys = owner->keys; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 58e586e6e789..79b7588fb8b6 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -268,7 +268,7 @@ struct ConnectionState { class ConnectionContext : public facade::ConnectionContext { public: - ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred); + ConnectionContext(facade::Connection* owner, dfly::acl::UserCredentials cred); ConnectionContext(const ConnectionContext* owner, Transaction* tx); struct DebugInfo { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index ee0cf4846369..c5664d1a25b2 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1581,13 +1581,12 @@ bool RequirePrivilegedAuth() { return !GetFlag(FLAGS_admin_nopass); } -facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer, - facade::Connection* owner) { +facade::ConnectionContext* Service::CreateContext(facade::Connection* owner) { auto cred = user_registry_.GetCredentials("default"); - ConnectionContext* res = new ConnectionContext{peer, owner, std::move(cred)}; + ConnectionContext* res = new ConnectionContext{owner, std::move(cred)}; res->ns = &namespaces->GetOrInsert(""); - if (peer->IsUDS()) { + if (owner->socket()->IsUDS()) { res->req_auth = false; res->skip_acl_validation = true; } else if (owner->IsPrivileged() && RequirePrivilegedAuth()) { @@ -1638,7 +1637,7 @@ void Service::Quit(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, ConnectionContext* cntx) { if (builder->GetProtocol() == Protocol::REDIS) builder->SendOk(); - using facade::SinkReplyBuilder; + builder->CloseConnection(); DeactivateMonitoring(static_cast(cntx)); diff --git a/src/server/main_service.h b/src/server/main_service.h index 710bdf77383c..471184ffdf88 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -62,8 +62,7 @@ class Service : public facade::ServiceInterface { void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, facade::MCReplyBuilder* builder, facade::ConnectionContext* cntx) final; - facade::ConnectionContext* CreateContext(util::FiberSocketBase* peer, - facade::Connection* owner) final; + facade::ConnectionContext* CreateContext(facade::Connection* owner) final; const CommandId* FindCmd(std::string_view) const; diff --git a/src/server/replica.cc b/src/server/replica.cc index a33a49a0440e..928ad610406e 100644 --- a/src/server/replica.cc +++ b/src/server/replica.cc @@ -593,7 +593,7 @@ error_code Replica::InitiateDflySync() { error_code Replica::ConsumeRedisStream() { base::IoBuf io_buf(16_KB); - ConnectionContext conn_context{static_cast(nullptr), nullptr, {}}; + ConnectionContext conn_context{nullptr, {}}; conn_context.is_replicating = true; conn_context.journal_emulated = true; conn_context.skip_acl_validation = true; diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 2799762f7d60..22bbb65e1e9a 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -677,7 +677,7 @@ void ExtendGeneric(CmdArgList args, bool prepend, Transaction* tx, SinkReplyBuil rb->SendLong(GetResult(std::move(res.value()))); } else { // Memcached skips if key is missing - DCHECK(dynamic_cast(builder)); + DCHECK(builder->GetProtocol() == Protocol::MEMCACHE); auto cb = [&](Transaction* t, EngineShard* shard) { return ExtendOrSkip(t->GetOpArgs(shard), key, value, prepend); @@ -1588,7 +1588,7 @@ void StringFamily::Register(CommandRegistry* registry) { << CI{"SUBSTR", CO::READONLY, 4, 1, 1}.HFUNC(GetRange) // Alias for GetRange << CI{"SETRANGE", CO::WRITE | CO::DENYOOM, 4, 1, 1}.HFUNC(SetRange) << CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, acl::THROTTLE}.HFUNC( - ClThrottle); + ClThrottle); } } // namespace dfly diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index f1a4e779151d..cf256aa95475 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -59,7 +59,7 @@ static vector SplitLines(const std::string& src) { TestConnection::TestConnection(Protocol protocol, io::StringSink* sink) : facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) { - cc_.reset(new dfly::ConnectionContext(sink_, this, {})); + cc_.reset(new dfly::ConnectionContext(this, {})); cc_->skip_acl_validation = true; SetSocket(ProactorBase::me()->CreateSocket()); OnConnectionStart(); @@ -125,6 +125,10 @@ class BaseFamilyTest::TestConnWrapper { return dummy_conn_.get(); } + SinkReplyBuilder* builder() { + return builder_.get(); + } + private: ::io::StringSink sink_; // holds the response blob @@ -133,10 +137,21 @@ class BaseFamilyTest::TestConnWrapper { std::vector> tmp_str_vec_; std::unique_ptr parser_; + std::unique_ptr builder_; }; BaseFamilyTest::TestConnWrapper::TestConnWrapper(Protocol proto) : dummy_conn_(new TestConnection(proto, &sink_)) { + switch (proto) { + case Protocol::REDIS: + builder_.reset(new RedisReplyBuilder{&sink_}); + break; + case Protocol::MEMCACHE: + builder_.reset(new MCReplyBuilder{&sink_}); + break; + default: + LOG(FATAL) << "Unknown protocol"; + } } BaseFamilyTest::TestConnWrapper::~TestConnWrapper() { @@ -390,7 +405,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { DCHECK(context->transaction == nullptr) << id; - service_->DispatchCommand(CmdArgList{args}, context->reply_builder_old(), context); + service_->DispatchCommand(CmdArgList{args}, conn_wrapper->builder(), context); DCHECK(context->transaction == nullptr); @@ -433,8 +448,7 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va DCHECK(context->transaction == nullptr); - service_->DispatchMC(cmd, value, static_cast(context->reply_builder_old()), - context); + service_->DispatchMC(cmd, value, static_cast(conn->builder()), context); DCHECK(context->transaction == nullptr); @@ -446,17 +460,7 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp return pp_->at(0)->Await([&] { return this->RunMC(cmd_type, key); }); } - MP::Command cmd; - cmd.type = cmd_type; - cmd.key = key; - TestConnWrapper* conn = AddFindConn(Protocol::MEMCACHE, GetId()); - - auto* context = conn->cmd_cntx(); - - service_->DispatchMC(cmd, string_view{}, - static_cast(context->reply_builder_old()), context); - - return conn->SplitLines(); + return RunMC(cmd_type, key, string_view{}, 0, chrono::seconds{}); } auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list list) @@ -479,9 +483,7 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_listcmd_cntx(); - - service_->DispatchMC(cmd, string_view{}, - static_cast(context->reply_builder_old()), context); + service_->DispatchMC(cmd, string_view{}, static_cast(conn->builder()), context); return conn->SplitLines(); }