Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: decouple reply_builder from ConnectionContext #4069

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 2 additions & 21 deletions src/facade/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,7 @@

namespace facade {

ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : owner_(owner) {
if (owner) {
protocol_ = owner->protocol();
}

if (stream) {
switch (protocol_) {
case Protocol::NONE:
LOG(DFATAL) << "Invalid protocol";
break;
case Protocol::REDIS: {
rbuilder_.reset(new RedisReplyBuilder(stream));
break;
}
case Protocol::MEMCACHE:
rbuilder_.reset(new MCReplyBuilder(stream));
break;
}
}

ConnectionContext::ConnectionContext(Connection* owner) : owner_(owner) {
conn_closing = false;
req_auth = false;
replica_conn = false;
Expand All @@ -46,7 +27,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, Connection* owner) : ow
}

size_t ConnectionContext::UsedMemory() const {
return dfly::HeapSize(rbuilder_) + dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands);
return dfly::HeapSize(authed_username) + dfly::HeapSize(acl_commands);
}

} // namespace facade
12 changes: 1 addition & 11 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Connection;

class ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, Connection* owner);
explicit ConnectionContext(Connection* owner);

virtual ~ConnectionContext() {
}
Expand All @@ -32,14 +32,6 @@ class ConnectionContext {
return owner_;
}

Protocol protocol() const {
return protocol_;
}

SinkReplyBuilder* reply_builder_old() {
return rbuilder_.get();
}

virtual size_t UsedMemory() const;

// connection state / properties.
Expand Down Expand Up @@ -71,8 +63,6 @@ class ConnectionContext {

private:
Connection* owner_;
Protocol protocol_ = Protocol::REDIS;
std::unique_ptr<SinkReplyBuilder> rbuilder_;
};

} // namespace facade
51 changes: 29 additions & 22 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,15 @@ void Connection::DispatchOperations::operator()(Connection::PipelineMessage& msg
DVLOG(2) << "Dispatching pipeline: " << ToSV(msg.args.front());

self->service_->DispatchCommand(CmdArgList{msg.args.data(), msg.args.size()},
self->reply_builder_, self->cc_.get());
self->reply_builder_.get(), self->cc_.get());

self->last_interaction_ = time(nullptr);
self->skip_next_squashing_ = false;
}

void Connection::DispatchOperations::operator()(const Connection::MCPipelineMessage& msg) {
self->service_->DispatchMC(msg.cmd, msg.value, static_cast<MCReplyBuilder*>(self->reply_builder_),
self->service_->DispatchMC(msg.cmd, msg.value,
static_cast<MCReplyBuilder*>(self->reply_builder_.get()),
self->cc_.get());
self->last_interaction_ = time(nullptr);
}
Expand Down Expand Up @@ -538,21 +539,17 @@ void UpdateLibNameVerMap(const string& name, const string& ver, int delta) {
Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx,
ServiceInterface* service)
: io_buf_(kMinReadSize),
protocol_(protocol),
http_listener_(http_listener),
ssl_ctx_(ctx),
service_(service),
flags_(0) {
static atomic_uint32_t next_id{1};

protocol_ = protocol;

constexpr size_t kReqSz = sizeof(Connection::PipelineMessage);
static_assert(kReqSz <= 256 && kReqSz >= 200);

switch (protocol) {
case Protocol::NONE:
LOG(DFATAL) << "Invalid protocol";
break;
case Protocol::REDIS:
redis_parser_.reset(new RedisParser(GetFlag(FLAGS_max_multi_bulk_len)));
break;
Expand Down Expand Up @@ -727,8 +724,7 @@ void Connection::HandleRequests() {
// because both Write and Recv internally check if the socket was shut
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder_old();
cc_.reset(service_->CreateContext(this));

if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
Expand All @@ -748,19 +744,28 @@ void Connection::HandleRequests() {
// Release the ownership of the socket from http_conn so it would stay with
// this connection.
http_conn.ReleaseSocket();
} else {
} else { // non-http
if (breaker_cb_) {
socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
}

switch (protocol_) {
case Protocol::REDIS:
reply_builder_.reset(new RedisReplyBuilder(socket_.get()));
break;
case Protocol::MEMCACHE:
reply_builder_.reset(new MCReplyBuilder(socket_.get()));
break;
default:
break;
}
ConnectionFlow();

socket_->CancelOnErrorCb(); // noop if nothing is registered.
VLOG(1) << "Closed connection for peer "
<< GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex());
reply_builder_.reset();
}
VLOG(1) << "Closed connection for peer "
<< GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex());
cc_.reset();
reply_builder_ = nullptr;
}
}

Expand Down Expand Up @@ -932,6 +937,8 @@ io::Result<bool> Connection::CheckForHttpProto() {
}

void Connection::ConnectionFlow() {
DCHECK(reply_builder_);

++stats_->num_conns;
++stats_->conn_received_cnt;
stats_->read_buf_capacity += io_buf_.Capacity();
Expand Down Expand Up @@ -989,7 +996,7 @@ void Connection::ConnectionFlow() {
VLOG(1) << "Error parser status " << parser_error_;

if (redis_parser_) {
SendProtocolError(RedisParser::Result(parser_error_), reply_builder_);
SendProtocolError(RedisParser::Result(parser_error_), reply_builder_.get());
} else {
DCHECK(memcache_parser_);
reply_builder_->SendProtocolError("bad command line format");
Expand Down Expand Up @@ -1092,7 +1099,7 @@ Connection::ParserStatus Connection::ParseRedis() {

auto dispatch_sync = [this, &parse_args, &cmd_vec] {
RespExpr::VecToArgList(parse_args, &cmd_vec);
service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_, cc_.get());
service_->DispatchCommand(absl::MakeSpan(cmd_vec), reply_builder_.get(), cc_.get());
};
auto dispatch_async = [this, &parse_args, tlh = mi_heap_get_backing()]() -> MessageHandle {
return {FromArgs(std::move(parse_args), tlh)};
Expand Down Expand Up @@ -1137,14 +1144,14 @@ auto Connection::ParseMemcache() -> ParserStatus {
string_view value;

auto dispatch_sync = [this, &cmd, &value] {
service_->DispatchMC(cmd, value, static_cast<MCReplyBuilder*>(reply_builder_), cc_.get());
service_->DispatchMC(cmd, value, static_cast<MCReplyBuilder*>(reply_builder_.get()), cc_.get());
};

auto dispatch_async = [&cmd, &value]() -> MessageHandle {
return {make_unique<MCPipelineMessage>(std::move(cmd), value)};
};

MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(reply_builder_);
MCReplyBuilder* builder = static_cast<MCReplyBuilder*>(reply_builder_.get());

do {
string_view str = ToSV(io_buf_.InputBuffer());
Expand Down Expand Up @@ -1377,7 +1384,7 @@ void Connection::SquashPipeline() {
cc_->async_dispatch = true;

size_t dispatched =
service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_, cc_.get());
service_->DispatchManyCommands(absl::MakeSpan(squash_cmds), reply_builder_.get(), cc_.get());

if (pending_pipeline_cmd_cnt_ == squash_cmds.size()) { // Flush if no new commands appeared
reply_builder_->Flush();
Expand All @@ -1400,7 +1407,7 @@ void Connection::SquashPipeline() {
}

void Connection::ClearPipelinedMessages() {
DispatchOperations dispatch_op{reply_builder_, this};
DispatchOperations dispatch_op{reply_builder_.get(), this};

// Recycle messages even from disconnecting client to keep properly track of memory stats
// As well as to avoid pubsub backpressure leakege.
Expand Down Expand Up @@ -1448,7 +1455,7 @@ std::string Connection::DebugInfo() const {
void Connection::ExecutionFiber() {
ThisFiber::SetName("ExecutionFiber");

DispatchOperations dispatch_op{reply_builder_, this};
DispatchOperations dispatch_op{reply_builder_.get(), this};

size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash);

Expand Down Expand Up @@ -1812,7 +1819,7 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const {
size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) +
dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) +
dfly::HeapSize(memcache_parser_) + dfly::HeapSize(redis_parser_) +
dfly::HeapSize(cc_);
dfly::HeapSize(cc_) + dfly::HeapSize(reply_builder_);

// We add a hardcoded 9k value to accomodate for the part of the Fiber stack that is in use.
// The allocated stack is actually larger (~130k), but only a small fraction of that (9k
Expand Down
8 changes: 1 addition & 7 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,6 @@ class Connection : public util::Connection {

bool IsMain() const;

Protocol protocol() const {
return protocol_;
}

void SetName(std::string name);

void SetLibName(std::string name);
Expand Down Expand Up @@ -404,9 +400,7 @@ class Connection : public util::Connection {
Protocol protocol_;
ConnectionStats* stats_ = nullptr;

// cc_->reply_builder may change during the lifetime of the connection, due to injections.
// This is a pointer to the original, socket based reply builder that never changes.
SinkReplyBuilder* reply_builder_ = nullptr;
std::unique_ptr<SinkReplyBuilder> reply_builder_;
util::HttpListenerBase* http_listener_;
SSL_CTX* ssl_ctx_;

Expand Down
4 changes: 4 additions & 0 deletions src/facade/dragonfly_listener.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class Listener : public util::ListenerInterface {
bool IsPrivilegedInterface() const;
bool IsMainInterface() const;

Protocol protocol() const {
return protocol_;
}

private:
util::Connection* NewConnection(ProactorBase* proactor) final;
ProactorBase* PickConnectionProactor(util::FiberSocketBase* sock) final;
Expand Down
17 changes: 13 additions & 4 deletions src/facade/facade.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,6 @@ ostream& operator<<(ostream& os, facade::CmdArgList ras) {
return os;
}

ostream& operator<<(ostream& os, facade::Protocol p) {
return os << int(p);
}

ostream& operator<<(ostream& os, const facade::RespExpr& e) {
using facade::RespExpr;
using facade::ToSV;
Expand Down Expand Up @@ -213,4 +209,17 @@ ostream& operator<<(ostream& os, facade::RespSpan ras) {
return os;
}

ostream& operator<<(ostream& os, facade::Protocol p) {
switch (p) {
case facade::Protocol::REDIS:
os << "REDIS";
break;
case facade::Protocol::MEMCACHE:
os << "MEMCACHE";
break;
}

return os;
}

} // namespace std
2 changes: 1 addition & 1 deletion src/facade/facade_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ constexpr size_t kSanitizerOverhead = 0u;
#endif
#endif

enum class Protocol : uint8_t { NONE = 0, MEMCACHE = 1, REDIS = 2 };
enum class Protocol : uint8_t { MEMCACHE = 1, REDIS = 2 };

using MutableSlice = std::string_view;
using CmdArgList = absl::Span<const std::string_view>;
Expand Down
4 changes: 2 additions & 2 deletions src/facade/ok_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class OkService : public ServiceInterface {
builder->SendError("");
}

ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) final {
return new ConnectionContext{peer, owner};
ConnectionContext* CreateContext(Connection* owner) final {
return new ConnectionContext{owner};
}
};

Expand Down
4 changes: 1 addition & 3 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ class SinkReplyBuilder {
}

public: // High level interface
virtual Protocol GetProtocol() const {
return Protocol::NONE;
}
virtual Protocol GetProtocol() const = 0;

virtual void SendLong(long val) = 0;
virtual void SendSimpleString(std::string_view str) = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/facade/service_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ServiceInterface {
virtual void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value,
MCReplyBuilder* builder, ConnectionContext* cntx) = 0;

virtual ConnectionContext* CreateContext(util::FiberSocketBase* peer, Connection* owner) = 0;
virtual ConnectionContext* CreateContext(Connection* owner) = 0;

virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) {
}
Expand Down
6 changes: 3 additions & 3 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "base/logging.h"
#include "core/overloaded.h"
#include "facade/dragonfly_connection.h"
#include "facade/dragonfly_listener.h"
#include "facade/facade_types.h"
#include "io/file.h"
#include "io/file_util.h"
Expand Down Expand Up @@ -102,14 +103,13 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn);
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
if (!connection->IsHttp() && connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_commands, update_keys, update_pub_sub});
}
};

if (main_listener_) {
if (main_listener_ && main_listener_->protocol() == facade::Protocol::REDIS) {
main_listener_->TraverseConnections(update_cb);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "facade/dragonfly_listener.h"
#include "facade/facade_types.h"
#include "helio/util/proactor_pool.h"
#include "server/acl/acl_commands_def.h"
Expand All @@ -20,6 +19,7 @@

namespace facade {
class SinkReplyBuilder;
class Listener;
} // namespace facade

namespace dfly {
Expand Down
7 changes: 3 additions & 4 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ const CommandId* StoredCmd::Cid() const {
return cid_;
}

ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* owner,
acl::UserCredentials cred)
: facade::ConnectionContext(stream, owner) {
ConnectionContext::ConnectionContext(facade::Connection* owner, acl::UserCredentials cred)
: facade::ConnectionContext(owner) {
if (owner) {
skip_acl_validation = owner->IsPrivileged();
}
Expand All @@ -117,7 +116,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own
}

ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx)
: facade::ConnectionContext(nullptr, nullptr), transaction{tx} {
: facade::ConnectionContext(nullptr), transaction{tx} {
if (owner) {
acl_commands = owner->acl_commands;
keys = owner->keys;
Expand Down
Loading
Loading